[Mlir-commits] [mlir] [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (PR #162095)
Dmitry Chigarev
llvmlistbot at llvm.org
Thu Oct 30 07:33:08 PDT 2025
https://github.com/dchigarev updated https://github.com/llvm/llvm-project/pull/162095
From ef520d09280a43cbc86cf760b24ab670508a0df8 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Mon, 6 Oct 2025 14:37:21 +0000
Subject: [PATCH 01/13] [MLIR][XeGPU][VectorToXeGPU] Lower
vector.load/store/transfer_read/transfer_write to new offsets syntax
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 218 ++++++++++++------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 12 +-
.../VectorToXeGPU/load-to-xegpu.mlir | 4 +-
.../VectorToXeGPU/store-to-xegpu.mlir | 4 +-
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 8 +-
.../transfer-write-to-xegpu.mlir | 4 +-
6 files changed, 171 insertions(+), 79 deletions(-)
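For reference, a minimal before/after sketch of the new lowering for a static-shape vector.load, condensed from the updated CHECK lines below (placeholder SSA names, attributes and cache hints omitted; not a verbatim test case):

    // Before: offsets were attached to the descriptor.
    %desc = xegpu.create_nd_tdesc %src[%off, %off]
        : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
    %vec = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

    // After: the descriptor is offset-free and the offsets move to the access op.
    %desc = xegpu.create_nd_tdesc %src
        : memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
    %vec = xegpu.load_nd %desc[%off, %off]
        : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>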
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index e2c7d803e5a5e..41526a7e34971 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,6 +97,64 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
+static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
+ SmallVector<OpFoldResult> &mixedShapes,
+ SmallVector<OpFoldResult> &mixedStrides,
+ SmallVector<int64_t> &strides,
+ TypedValue<MemRefType> src) {
+ auto srcTy = src.getType();
+ // In case of any dynamic shapes, source's shape and strides have to be
+ // explicitly provided.
+ SmallVector<Value> sourceDims;
+ unsigned srcRank = srcTy.getRank();
+ for (unsigned i = 0; i < srcRank; ++i)
+ sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
+
+ for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+ if (shape == ShapedType::kDynamic)
+ mixedShapes.push_back(sourceDims[idx]);
+ else
+ mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
+ }
+
+ // Compute strides in reverse order.
+ Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ // Last stride is guaranteed to be static and unit.
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+ for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+ accStride =
+ arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+ if (strides[i] == ShapedType::kDynamic)
+ mixedStrides.push_back(accStride);
+ else
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+ }
+ std::reverse(mixedStrides.begin(), mixedStrides.end());
+}
+
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+ Location loc,
+ xegpu::TensorDescType descType,
+ TypedValue<MemRefType> src) {
+ MemRefType srcTy = src.getType();
+ auto [strides, offset] = srcTy.getStridesAndOffset();
+
+ xegpu::CreateNdDescOp ndDesc;
+ if (srcTy.hasStaticShape())
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
+ else {
+ SmallVector<OpFoldResult> mixedShapes;
+ SmallVector<OpFoldResult> mixedStrides;
+ computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+ src);
+
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ mixedShapes, mixedStrides);
+ }
+
+ return ndDesc;
+}
+
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
xegpu::TensorDescType descType, TypedValue<MemRefType> src,
@@ -109,45 +167,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
getAsOpFoldResult(offsets));
} else {
- // In case of any dynamic shapes, source's shape and strides have to be
- // explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- SmallVector<int64_t> constOffsets;
- SmallVector<Value> dynOffsets;
+ SmallVector<OpFoldResult> mixedOffsets;
for (Value offset : offsets) {
std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (!staticVal)
- dynOffsets.push_back(offset);
- constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
- }
-
- SmallVector<Value> dynShapes;
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- dynShapes.push_back(sourceDims[idx]);
+ if (staticVal)
+ mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+ else
+ mixedOffsets.push_back(offset);
}
- // Compute strides in reverse order.
- SmallVector<Value> dynStrides;
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- dynStrides.push_back(accStride);
- }
- std::reverse(dynStrides.begin(), dynStrides.end());
+ SmallVector<OpFoldResult> mixedShapes;
+ SmallVector<OpFoldResult> mixedStrides;
+ computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
+ src);
ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
- DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
- DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
- DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
}
return ndDesc;
@@ -523,21 +558,35 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::LoadNdOp loadOp;
+
+ if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+ loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ getAsOpFoldResult(readOp.getIndices()),
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ } else {
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
+ readOp.getIndices());
+
+ loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ }
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -579,17 +628,30 @@ struct TransferWriteLowering
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
-
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp storeOp;
+ if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+ storeOp =
+ xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ getAsOpFoldResult(writeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ } else {
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
+ writeOp.getIndices());
+
+ storeOp =
+ xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ }
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -674,19 +736,32 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::LoadNdOp loadNdOp;
+
+ if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+ loadNdOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ } else {
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
+ loadNdOp =
+ xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ }
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -711,15 +786,28 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::StoreNdOp storeNdOp;
+ if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+ storeNdOp =
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ getAsOpFoldResult(storeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ } else {
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
+
+ storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
+ }
+
rewriter.replaceOp(storeOp, storeNdOp);
return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..8ed8b26dd2a0e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -258,8 +258,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -320,8 +322,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..c7c0485768b99 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -72,9 +72,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..19240abe1e75c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -74,9 +74,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..72bdab0a4db3a 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -83,9 +83,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
@@ -109,9 +109,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET1:.+]]: index,
// LOAD-ND-SAME: %[[OFFSET2:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND-SAME: -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..ca3bbc11a5180 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -126,9 +126,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_out_of_bounds(
// STORE-SCATTER: vector.transfer_write
From 46af25a0944431b69084a640dd399082d398e3c9 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:45:14 +0000
Subject: [PATCH 02/13] Relax len(offsets) == tdescRank requirement
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 210 +++++-------------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 96 ++++++--
.../VectorToXeGPU/load-to-xegpu.mlir | 12 +-
.../VectorToXeGPU/store-to-xegpu.mlir | 12 +-
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 20 +-
.../transfer-write-to-xegpu.mlir | 16 +-
mlir/test/Dialect/XeGPU/invalid.mlir | 14 +-
7 files changed, 172 insertions(+), 208 deletions(-)
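In short, the nd-op verifiers now accept an offsets list whose rank matches either the source memref or the tensor descriptor. An illustrative sketch based on the updated tests (%o stands for an index value; placeholder names, types abbreviated):

    %desc = xegpu.create_nd_tdesc %src
        : memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>

    // Accepted: three offsets match the source rank, even though the tdesc is 2-D.
    %vec = xegpu.load_nd %desc[%o, %o, %o]
        : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>

    // Still rejected: a single offset matches neither the source nor the tdesc rank.
    %bad = xegpu.load_nd %desc[%o]
        : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>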
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 41526a7e34971..f3dcb31f6b0be 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,41 +97,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
-static void computeMixedShapesStrides(PatternRewriter &rewriter, Location loc,
- SmallVector<OpFoldResult> &mixedShapes,
- SmallVector<OpFoldResult> &mixedStrides,
- SmallVector<int64_t> &strides,
- TypedValue<MemRefType> src) {
- auto srcTy = src.getType();
- // In case of any dynamic shapes, source's shape and strides have to be
- // explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- mixedShapes.push_back(sourceDims[idx]);
- else
- mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
- }
-
- // Compute strides in reverse order.
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- mixedStrides.push_back(accStride);
- else
- mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
- }
- std::reverse(mixedStrides.begin(), mixedStrides.end());
-}
-
static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
Location loc,
xegpu::TensorDescType descType,
@@ -143,46 +108,38 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
if (srcTy.hasStaticShape())
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
else {
- SmallVector<OpFoldResult> mixedShapes;
- SmallVector<OpFoldResult> mixedStrides;
- computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
- src);
+ // In case of any dynamic shapes, source's shape and strides have to be
+ // explicitly provided.
+ SmallVector<Value> sourceDims;
+ unsigned srcRank = srcTy.getRank();
+ for (unsigned i = 0; i < srcRank; ++i)
+ sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- mixedShapes, mixedStrides);
- }
-
- return ndDesc;
-}
-
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
- xegpu::TensorDescType descType, TypedValue<MemRefType> src,
- Operation::operand_range offsets) {
- MemRefType srcTy = src.getType();
- auto [strides, offset] = srcTy.getStridesAndOffset();
-
- xegpu::CreateNdDescOp ndDesc;
- if (srcTy.hasStaticShape()) {
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- getAsOpFoldResult(offsets));
- } else {
- SmallVector<OpFoldResult> mixedOffsets;
- for (Value offset : offsets) {
- std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (staticVal)
- mixedOffsets.push_back(rewriter.getI64IntegerAttr(staticVal.value()));
+ SmallVector<OpFoldResult> mixedShapes;
+ for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
+ if (shape == ShapedType::kDynamic)
+ mixedShapes.push_back(sourceDims[idx]);
else
- mixedOffsets.push_back(offset);
+ mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
}
- SmallVector<OpFoldResult> mixedShapes;
+ // Compute strides in reverse order.
SmallVector<OpFoldResult> mixedStrides;
- computeMixedShapesStrides(rewriter, loc, mixedShapes, mixedStrides, strides,
- src);
+ Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ // Last stride is guaranteed to be static and unit.
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
+ for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
+ accStride =
+ arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
+ if (strides[i] == ShapedType::kDynamic)
+ mixedStrides.push_back(accStride);
+ else
+ mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
+ }
+ std::reverse(mixedStrides.begin(), mixedStrides.end());
- ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, mixedOffsets, mixedShapes, mixedStrides);
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ mixedShapes, mixedStrides);
}
return ndDesc;
@@ -564,29 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::LoadNdOp loadOp;
-
- if (vecTy.getRank() == readOp.getBase().getType().getRank()) {
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
- loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- getAsOpFoldResult(readOp.getIndices()),
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- } else {
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
- loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- }
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ getAsOpFoldResult(readOp.getIndices()),
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -630,28 +573,15 @@ struct TransferWriteLowering
xegpu::MemorySpace::Global);
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::StoreNdOp storeOp;
- if (vecTy.getRank() == writeOp.getBase().getType().getRank()) {
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
-
- storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- getAsOpFoldResult(writeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- } else {
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
-
- storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- }
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+
+ auto storeOp =
+ xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
+ getAsOpFoldResult(writeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -743,25 +673,13 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::LoadNdOp loadNdOp;
-
- if (vecTy.getRank() == loadOp.getBase().getType().getRank()) {
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
- loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
- /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- } else {
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- loadNdOp =
- xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- }
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
+ auto loadNdOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -789,24 +707,14 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::StoreNdOp storeNdOp;
- if (vecTy.getRank() == storeOp.getBase().getType().getRank()) {
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
-
- storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
- getAsOpFoldResult(storeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- } else {
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
-
- storeNdOp = xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
- }
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+
+ auto storeNdOp =
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ getAsOpFoldResult(storeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8ed8b26dd2a0e..b565e39464b52 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -475,14 +475,30 @@ LogicalResult PrefetchNdOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
+ auto tDesc = getTensorDesc();
+ if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+ // If CreateNdDescOp is available, we can further
+ // check the offsets rank against the source rank.
+ auto staticSource = createTDescOp.getConstShapeAttr();
+ int64_t sourceRank;
+ if (!staticSource || staticSource.empty()) {
+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+ sourceRank = sourceTy.getRank();
+ } else
+ sourceRank = staticSource.size();
+
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ auto tDescRank = tdescTy.getRank();
+ bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+ bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+ if (sourceRankMismatch && tdescRankMismatch)
+ return emitOpError(
+ "Offsets rank must match either the source or the TensorDesc rank.");
+ }
return success();
}
@@ -600,14 +616,30 @@ LogicalResult LoadNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< tdescTy;
- int64_t tDescRank = tdescTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
+ auto tDesc = getTensorDesc();
+ if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+ // If CreateNdDescOp is available, we can further
+ // check the offsets rank against the source rank.
+ auto staticSource = createTDescOp.getConstShapeAttr();
+ int64_t sourceRank;
+ if (!staticSource || staticSource.empty()) {
+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+ sourceRank = sourceTy.getRank();
+ } else
+ sourceRank = staticSource.size();
+
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ auto tDescRank = tdescTy.getRank();
+ bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+ bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+ if (sourceRankMismatch && tdescRankMismatch)
+ return emitOpError(
+ "Offsets rank must match either the source or the TensorDesc rank.");
+ }
return success();
}
@@ -694,14 +726,30 @@ LogicalResult StoreNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< dstTy;
- int64_t tDescRank = dstTy.getRank();
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
- return emitOpError(
- "Mismatched ranks between offsets and tensor descriptor");
+ auto tDesc = getTensorDesc();
+ if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+ // If CreateNdDescOp is available, we can further
+ // check the offsets rank against the source rank.
+ auto staticSource = createTDescOp.getConstShapeAttr();
+ int64_t sourceRank;
+ if (!staticSource || staticSource.empty()) {
+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+ sourceRank = sourceTy.getRank();
+ } else
+ sourceRank = staticSource.size();
+
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ auto tDescRank = dstTy.getRank();
+ bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+ bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+ if (sourceRankMismatch && tdescRankMismatch)
+ return emitOpError(
+ "Offsets rank must match either the source or the TensorDesc rank.");
+ }
return success();
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index c7c0485768b99..b5fb2c4aa3e27 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -10,10 +10,10 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
// -----
@@ -29,9 +29,9 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
@@ -53,10 +53,10 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 19240abe1e75c..57e754f7d7c00 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -12,10 +12,10 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
// -----
@@ -31,9 +31,9 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
@@ -55,10 +55,10 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 72bdab0a4db3a..78a2692119142 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -13,10 +13,10 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME: %[[SRC]]
// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_1D_vector(
@@ -47,10 +47,10 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME: %[[SRC]]
// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_2D_vector(
@@ -151,8 +151,8 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -186,8 +186,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-LABEL: @load_dynamic_source2(
// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
// LOAD-GATHER-LABEL: @load_dynamic_source2(
@@ -419,10 +419,10 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// LOAD-ND-SAME: %[[SUBVIEW]]
// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8xf16>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index ca3bbc11a5180..e1b754f952bbe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -16,10 +16,10 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
// STORE-SCATTER-LABEL: @store_1D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>,
@@ -50,10 +50,10 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_2D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -91,10 +91,10 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -301,10 +301,10 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
+// STORE-ND-SAME: %[[SUBVIEW]]
// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]], %[[OFF2]]] : vector<8xf16>
// STORE-SCATTER-LABEL: @store_to_subview(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>,
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index ebbe3ce0ec0d0..00a586dee1f51 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -135,7 +135,7 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
// -----
func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
%1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
+// expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
%2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
return
}
@@ -143,7 +143,7 @@ func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
// -----
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
+ // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
return
}
@@ -152,11 +152,19 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- // expected-error at +1 {{Mismatched ranks between offsets and tensor descriptor}}
+ // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
return
}
+// -----
+func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
+ %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
+ // expected-error at +1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+ %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
+ return
+}
+
// -----
func.func @load_nd_layout(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
From 65f57c7f14a844ec03928887c8c867ca4d6324d5 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 10:46:28 +0000
Subject: [PATCH 03/13] Apply formatting
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 39 +++++++++----------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 24 +++++++-----
2 files changed, 34 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index f3dcb31f6b0be..7f11d427191e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -521,15 +521,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
ArrayRef<int64_t>{1, 0});
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
- getAsOpFoldResult(readOp.getIndices()),
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
+
+ auto loadOp = xegpu::LoadNdOp::create(
+ rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -573,15 +573,15 @@ struct TransferWriteLowering
xegpu::MemorySpace::Global);
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+ xegpu::CreateNdDescOp ndDesc =
+ createNdDescriptor(rewriter, loc, descType,
+ dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
auto storeOp =
xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- getAsOpFoldResult(writeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ getAsOpFoldResult(writeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -710,11 +710,10 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
xegpu::CreateNdDescOp ndDesc =
createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
- auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
- getAsOpFoldResult(storeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ auto storeNdOp = xegpu::StoreNdOp::create(
+ rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b565e39464b52..0435216e306af 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -484,16 +484,18 @@ LogicalResult PrefetchNdOp::verify() {
if (!staticSource || staticSource.empty()) {
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
sourceRank = sourceTy.getRank();
- } else
+ } else
sourceRank = staticSource.size();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
auto tDescRank = tdescTy.getRank();
- bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ bool sourceRankMismatch =
+ ((offsetSize != 0) && (offsetSize != sourceRank)) ||
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ bool tdescRankMismatch =
+ ((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
if (sourceRankMismatch && tdescRankMismatch)
return emitOpError(
@@ -625,16 +627,18 @@ LogicalResult LoadNdOp::verify() {
if (!staticSource || staticSource.empty()) {
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
sourceRank = sourceTy.getRank();
- } else
+ } else
sourceRank = staticSource.size();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
auto tDescRank = tdescTy.getRank();
- bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ bool sourceRankMismatch =
+ ((offsetSize != 0) && (offsetSize != sourceRank)) ||
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ bool tdescRankMismatch =
+ ((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
if (sourceRankMismatch && tdescRankMismatch)
return emitOpError(
@@ -735,16 +739,18 @@ LogicalResult StoreNdOp::verify() {
if (!staticSource || staticSource.empty()) {
auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
sourceRank = sourceTy.getRank();
- } else
+ } else
sourceRank = staticSource.size();
int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
int64_t constOffsetSize =
getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
auto tDescRank = dstTy.getRank();
- bool sourceRankMismatch = ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ bool sourceRankMismatch =
+ ((offsetSize != 0) && (offsetSize != sourceRank)) ||
((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch = ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ bool tdescRankMismatch =
+ ((offsetSize != 0) && (offsetSize != tDescRank)) ||
((constOffsetSize != 0) && (constOffsetSize != tDescRank));
if (sourceRankMismatch && tdescRankMismatch)
return emitOpError(
From 49d38a079b72d14c5cd40cb5096814935d78fe9a Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:06:12 +0000
Subject: [PATCH 04/13] generalize 'offsets-check'
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 120 ++++++++-----------------
mlir/test/Dialect/XeGPU/invalid.mlir | 8 --
2 files changed, 39 insertions(+), 89 deletions(-)
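The rank check is now a single shared helper applied by prefetch_nd, load_nd, and store_nd alike. A rough sketch of what it accepts and rejects for a memref<4x8x16xf16> source and a 2-D descriptor (illustrative, mirroring the invalid.mlir cases; %src, %t, %v, %x are placeholder names):

    %t = xegpu.create_nd_tdesc %src : memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
    xegpu.prefetch_nd %t[0, 0, 0] : !xegpu.tensor_desc<8x16xf16>                    // ok: source rank
    %v = xegpu.load_nd %t[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>  // ok: tdesc rank
    xegpu.store_nd %v, %t[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>      // error: matches neither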
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0435216e306af..0624e8c4a6a38 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,6 +121,39 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+// Verify that number of offsets matches either the source rank or the tdesc
+// rank.
+static LogicalResult
+isValidNdOffset(TypedValue<TensorDescType> tDesc,
+ std::optional<llvm::ArrayRef<long int>> constOffsets,
+ int64_t offsetSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+ if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
+ // If CreateNdDescOp is available, we can further
+ // check the offsets rank against the source rank.
+ auto staticSource = createTDescOp.getConstShapeAttr();
+ int64_t sourceRank;
+ if (!staticSource || staticSource.empty()) {
+ auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
+ sourceRank = sourceTy.getRank();
+ } else
+ sourceRank = staticSource.size();
+
+ int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+ auto tDescRank = tDesc.getType().getRank();
+ bool sourceRankMismatch =
+ ((offsetSize != 0) && (offsetSize != sourceRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
+ bool tdescRankMismatch =
+ ((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
+ if (sourceRankMismatch && tdescRankMismatch)
+ return emitError() << "Offsets rank must match either the source or the "
+ "TensorDesc rank.";
+ }
+ return success();
+}
+
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
VectorType valueTy, int64_t chunkSize,
@@ -476,33 +509,8 @@ LogicalResult PrefetchNdOp::verify() {
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
auto tDesc = getTensorDesc();
- if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
- // If CreateNdDescOp is available, we can further
- // check the offsets rank against the source rank.
- auto staticSource = createTDescOp.getConstShapeAttr();
- int64_t sourceRank;
- if (!staticSource || staticSource.empty()) {
- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
- sourceRank = sourceTy.getRank();
- } else
- sourceRank = staticSource.size();
-
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- auto tDescRank = tdescTy.getRank();
- bool sourceRankMismatch =
- ((offsetSize != 0) && (offsetSize != sourceRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch =
- ((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
- if (sourceRankMismatch && tdescRankMismatch)
- return emitOpError(
- "Offsets rank must match either the source or the TensorDesc rank.");
- }
-
- return success();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
@@ -619,33 +627,8 @@ LogicalResult LoadNdOp::verify() {
<< tdescTy;
auto tDesc = getTensorDesc();
- if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
- // If CreateNdDescOp is available, we can further
- // check the offsets rank against the source rank.
- auto staticSource = createTDescOp.getConstShapeAttr();
- int64_t sourceRank;
- if (!staticSource || staticSource.empty()) {
- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
- sourceRank = sourceTy.getRank();
- } else
- sourceRank = staticSource.size();
-
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- auto tDescRank = tdescTy.getRank();
- bool sourceRankMismatch =
- ((offsetSize != 0) && (offsetSize != sourceRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch =
- ((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
- if (sourceRankMismatch && tdescRankMismatch)
- return emitOpError(
- "Offsets rank must match either the source or the TensorDesc rank.");
- }
-
- return success();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
@@ -731,33 +714,8 @@ LogicalResult StoreNdOp::verify() {
<< dstTy;
auto tDesc = getTensorDesc();
- if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
- // If CreateNdDescOp is available, we can further
- // check the offsets rank against the source rank.
- auto staticSource = createTDescOp.getConstShapeAttr();
- int64_t sourceRank;
- if (!staticSource || staticSource.empty()) {
- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
- sourceRank = sourceTy.getRank();
- } else
- sourceRank = staticSource.size();
-
- int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
- int64_t constOffsetSize =
- getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
- auto tDescRank = dstTy.getRank();
- bool sourceRankMismatch =
- ((offsetSize != 0) && (offsetSize != sourceRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch =
- ((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
- if (sourceRankMismatch && tdescRankMismatch)
- return emitOpError(
- "Offsets rank must match either the source or the TensorDesc rank.");
- }
-
- return success();
+ return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
+ [&]() { return emitOpError(); });
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 00a586dee1f51..614f21bcebc48 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -157,14 +157,6 @@ func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
return
}
-// -----
-func.func @subgroup_load_nd_offset_4(%src: memref<4x8x16xf16>, %x : index) {
- %3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
- %5 = xegpu.load_nd %3[0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- return
-}
-
// -----
func.func @load_nd_layout(%src: memref<24x32xf32>) {
%1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16xf32>
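
For reference, the rank rule factored out into `isValidNdOffset` above can be modeled outside of MLIR as a plain predicate over the four ranks involved. A minimal standalone C++ sketch (illustrative names; plain integers stand in for the op's offset operands and const_offsets attribute):

```cpp
#include <cassert>
#include <cstdint>

// Standalone model of the offsets-rank rule: a non-empty offset list
// (dynamic or constant) must match either the source rank or the
// tensor-descriptor rank; an empty list is always accepted.
static bool offsetsRankValid(int64_t offsetSize, int64_t constOffsetSize,
                             int64_t sourceRank, int64_t tDescRank) {
  bool sourceRankMismatch =
      (offsetSize != 0 && offsetSize != sourceRank) ||
      (constOffsetSize != 0 && constOffsetSize != sourceRank);
  bool tdescRankMismatch =
      (offsetSize != 0 && offsetSize != tDescRank) ||
      (constOffsetSize != 0 && constOffsetSize != tDescRank);
  // Reject only when the offsets match neither rank.
  return !(sourceRankMismatch && tdescRankMismatch);
}

int main() {
  // memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>: sourceRank=3, tDescRank=2.
  assert(offsetsRankValid(2, 0, 3, 2));  // matches the TensorDesc rank
  assert(offsetsRankValid(3, 0, 3, 2));  // matches the source rank
  assert(!offsetsRankValid(1, 0, 3, 2)); // matches neither, rejected
  return 0;
}
```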
>From f9f73ad7c7921d0ce038fae526fce6df23c85e97 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 16 Oct 2025 11:23:52 +0000
Subject: [PATCH 05/13] fix windows build
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0624e8c4a6a38..b3bdfc58abafc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -125,7 +125,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
// rank.
static LogicalResult
isValidNdOffset(TypedValue<TensorDescType> tDesc,
- std::optional<llvm::ArrayRef<long int>> constOffsets,
+ std::optional<llvm::ArrayRef<int64_t>> constOffsets,
int64_t offsetSize,
function_ref<InFlightDiagnostic()> emitError) {
if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
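
The Windows build break is an integer data-model issue: on LP64 targets (Linux, macOS) `long` happens to be 64-bit, so `ArrayRef<long int>` and `ArrayRef<int64_t>` name the same type, while on LLP64 Windows `long` is 32-bit and the call no longer matches the helper's signature. A small standalone check that illustrates the difference, assuming the usual LP64/LLP64 conventions:

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

// Prints the width of 'long' versus 'int64_t' on the current target.
// Using the fixed-width type in the helper's signature removes the
// platform dependence that broke the Windows build.
int main() {
  std::printf("sizeof(long) = %zu, sizeof(int64_t) = %zu\n", sizeof(long),
              sizeof(int64_t));
#if LONG_MAX == INT64_MAX
  std::puts("LP64-style target: long is as wide as int64_t here");
#else
  std::puts("LLP64-style target (e.g. 64-bit Windows): long is narrower");
#endif
  return 0;
}
```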
>From 2c75b294e9ea00e1d9cb12de9cf26ac7a5121ef5 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:31 +0000
Subject: [PATCH 06/13] add docs for new offset syntax
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 45 +++++++++++++++++++
1 file changed, 45 insertions(+)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 426377fcf598f..93c9f305c080c 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,6 +261,21 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
: !xegpu.tensor_desc<8x16xf16>
```
+ The operation may take optional offsets for the tensor descriptor.
+ The number of offsets must be greater or equal to the rank of the tensor descriptor
+ and less than the rank of the source memref. The offsets are applied to the innermost
+ dimension of the source memref.
+
+ Examples:
+ ```mlir
+ %tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+ // memref[0, 0, %off0, %off1]
+ xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
+ // memref[0, %off0, %off1, %off2]
+ xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+ // memref[%off0, %off1, %off2, %off3]
+ xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+ ```
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -350,6 +365,21 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
```
+ The operation may take optional offsets for the tensor descriptor.
+ The number of offsets must be greater or equal to the rank of the tensor descriptor
+ and less than the rank of the source memref. The offsets are applied to the innermost
+ dimension of the source memref.
+
+ Examples:
+ ```mlir
+ %1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+ // memref[0, 0, %off0, %off1]
+ xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ // memref[0, %off0, %off1, %off2]
+ xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ // memref[%off0, %off1, %off2, %off3]
+ xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+ ```
}];
@@ -445,6 +475,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
```
+ The operation may take optional offsets for the tensor descriptor.
+ The number of offsets must be greater or equal to the rank of the tensor descriptor
+ and less than the rank of the source memref. The offsets are applied to the innermost
+ dimension of the source memref.
+
+ Examples:
+ ```mlir
+ %2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
+ // memref[0, 0, %off0, %off1]
+ xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ // memref[0, %off0, %off1, %off2]
+ xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ // memref[%off0, %off1, %off2, %off3]
+ xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
+ ```
}];
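
The examples above all rely on the same convention: when fewer offsets than source dimensions are supplied, the missing leading dimensions are addressed at zero and the given offsets apply to the innermost ones. A standalone sketch of that expansion as described in the documentation text (the helper name is illustrative and not part of the dialect):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Pads a partial offset list with leading zeros so it addresses all source
// dimensions, mirroring the "memref[0, 0, %off0, %off1]" comments above.
static std::vector<int64_t> expandOffsets(const std::vector<int64_t> &offsets,
                                          size_t sourceRank) {
  assert(offsets.size() <= sourceRank && "more offsets than source dims");
  std::vector<int64_t> full(sourceRank - offsets.size(), 0); // implicit zeros
  full.insert(full.end(), offsets.begin(), offsets.end());
  return full;
}

int main() {
  // memref<2x8x32x32xf32> with offsets [%off0, %off1] -> memref[0, 0, %off0, %off1].
  assert((expandOffsets({5, 7}, 4) == std::vector<int64_t>{0, 0, 5, 7}));
  return 0;
}
```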
>From 15f3aa70fd6709bec1a9dfed73d3c44d2dc0acf2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:56:59 +0000
Subject: [PATCH 07/13] Update validation to not depend on 'create_nd_tdesc' op
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 29 ++++++--------------------
1 file changed, 6 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index b3bdfc58abafc..76640bb59be46 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -128,29 +128,12 @@ isValidNdOffset(TypedValue<TensorDescType> tDesc,
std::optional<llvm::ArrayRef<int64_t>> constOffsets,
int64_t offsetSize,
function_ref<InFlightDiagnostic()> emitError) {
- if (auto createTDescOp = tDesc.getDefiningOp<CreateNdDescOp>()) {
- // If CreateNdDescOp is available, we can further
- // check the offsets rank against the source rank.
- auto staticSource = createTDescOp.getConstShapeAttr();
- int64_t sourceRank;
- if (!staticSource || staticSource.empty()) {
- auto sourceTy = dyn_cast<MemRefType>(createTDescOp.getSourceType());
- sourceRank = sourceTy.getRank();
- } else
- sourceRank = staticSource.size();
-
- int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
- auto tDescRank = tDesc.getType().getRank();
- bool sourceRankMismatch =
- ((offsetSize != 0) && (offsetSize != sourceRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != sourceRank));
- bool tdescRankMismatch =
- ((offsetSize != 0) && (offsetSize != tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize != tDescRank));
- if (sourceRankMismatch && tdescRankMismatch)
- return emitError() << "Offsets rank must match either the source or the "
- "TensorDesc rank.";
- }
+ int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
+ auto tDescRank = tDesc.getType().getRank();
+ if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
+ return emitError() << "Offsets rank cannot be smaller than tensor "
+ "descriptor rank.";
return success();
}
>From ccf8b92429503fd2ee9fcf175004708f51e7fe86 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 22:57:19 +0000
Subject: [PATCH 08/13] Use memref.extract_strided_metadata to compute strides
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 21 ++++---------------
.../VectorToXeGPU/load-to-xegpu.mlir | 4 ++--
.../VectorToXeGPU/store-to-xegpu.mlir | 4 ++--
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 4 ++--
.../transfer-write-to-xegpu.mlir | 4 ++--
mlir/test/Dialect/XeGPU/invalid.mlir | 12 ++---------
6 files changed, 14 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 7f11d427191e5..0f031be26cebc 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -105,9 +105,9 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
- if (srcTy.hasStaticShape())
+ if (srcTy.hasStaticShape()) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
- else {
+ } else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
SmallVector<Value> sourceDims;
@@ -123,21 +123,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
}
- // Compute strides in reverse order.
- SmallVector<OpFoldResult> mixedStrides;
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- mixedStrides.push_back(rewriter.getI64IntegerAttr(1));
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- mixedStrides.push_back(accStride);
- else
- mixedStrides.push_back(rewriter.getI64IntegerAttr(strides[i]));
- }
- std::reverse(mixedStrides.begin(), mixedStrides.end());
-
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+ SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
mixedShapes, mixedStrides);
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index b5fb2c4aa3e27..1975c96bfe796 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -52,9 +52,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 57e754f7d7c00..63e78ca20bcee 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -54,9 +54,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 78a2692119142..81527a8111bb0 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -150,7 +150,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -186,7 +186,7 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-LABEL: @load_dynamic_source2(
// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index e1b754f952bbe..83d33e1905f7c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -90,9 +90,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// STORE-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 614f21bcebc48..4b710d3f51557 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -132,18 +132,10 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
return
}
-// -----
-func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
- %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
-// expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
- %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
- return
-}
-
// -----
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+ // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
return
}
@@ -152,7 +144,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- // expected-error@+1 {{Offsets rank must match either the source or the TensorDesc rank.}}
+ // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
return
}
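
The stride arithmetic removed in this patch is the standard row-major accumulation, which `memref.extract_strided_metadata` now supplies directly. A minimal standalone model of that computation, assuming a contiguous source:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Row-major strides for a contiguous buffer: each stride is the product of
// all inner dimension sizes, accumulated from the innermost dimension out.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size(), 1);
  for (int i = static_cast<int>(sizes.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * sizes[i + 1];
  return strides;
}

int main() {
  // memref<?x8x16xf32> with the dynamic extent resolved to 4 at runtime
  // yields strides [128, 16, 1], as in the updated CHECK lines above.
  assert((rowMajorStrides({4, 8, 16}) == std::vector<int64_t>{128, 16, 1}));
  return 0;
}
```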
>From 0b26a417c5c4e5a33a319d70564385b332dba39b Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Tue, 21 Oct 2025 23:00:36 +0000
Subject: [PATCH 09/13] apply clang-format
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 0f031be26cebc..11bf3152e5cc4 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -124,7 +124,8 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
}
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
- SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(), meta.getStrides().end());
+ SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
+ meta.getStrides().end());
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
mixedShapes, mixedStrides);
}
>From 1f0e95384aa384278b240a38aee1ea40dc21c2d2 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 22 Oct 2025 09:14:56 +0000
Subject: [PATCH 10/13] fix docs
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 20 +++++++++----------
1 file changed, 10 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 93c9f305c080c..489bd513a0bd4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -262,9 +262,9 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
```
The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater or equal to the rank of the tensor descriptor
- and less than the rank of the source memref. The offsets are applied to the innermost
- dimension of the source memref.
+ The number of offsets must be greater than or equal to the rank of the tensor
+ descriptor and less than or equal to the rank of the source memref.
+ The offsets are applied to the innermost dimensions of the source memref.
Examples:
```mlir
@@ -274,7 +274,7 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
// memref[0, %off0, %off1, %off2]
xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
// memref[%off0, %off1, %off2, %off3]
- xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
+ xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
```
}];
@@ -366,9 +366,9 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
```
The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater or equal to the rank of the tensor descriptor
- and less than the rank of the source memref. The offsets are applied to the innermost
- dimension of the source memref.
+ The number of offsets must be greater than or equal to the rank of the tensor
+ descriptor and less than or equal to the rank of the source memref.
+ The offsets are applied to the innermost dimensions of the source memref.
Examples:
```mlir
@@ -476,9 +476,9 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
```
The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater or equal to the rank of the tensor descriptor
- and less than the rank of the source memref. The offsets are applied to the innermost
- dimension of the source memref.
+ The number of offsets must be greater than or equal to the rank of the tensor
+ descriptor and less than or equal to the rank of the source memref.
+ The offsets are applied to the innermost dimensions of the source memref.
Examples:
```mlir
>From 173eb6ddc4f716dd12e2d94d5e7494abba53ce97 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Wed, 22 Oct 2025 12:03:51 +0000
Subject: [PATCH 11/13] use extractStridedMetadataOp to compute shapes for
tdesc
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 18 ++----------------
.../VectorToXeGPU/load-to-xegpu.mlir | 10 ++--------
.../VectorToXeGPU/store-to-xegpu.mlir | 10 ++--------
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 13 +++----------
.../VectorToXeGPU/transfer-write-to-xegpu.mlir | 10 ++--------
5 files changed, 11 insertions(+), 50 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 11bf3152e5cc4..ee2e8a69edcc0 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -110,24 +110,10 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- SmallVector<OpFoldResult> mixedShapes;
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- mixedShapes.push_back(sourceDims[idx]);
- else
- mixedShapes.push_back(rewriter.getI64IntegerAttr(shape));
- }
-
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
- SmallVector<OpFoldResult> mixedStrides(meta.getStrides().begin(),
- meta.getStrides().end());
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- mixedShapes, mixedStrides);
+ meta.getConstifiedMixedSizes(),
+ meta.getConstifiedMixedStrides());
}
return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 1975c96bfe796..a3ed559f6413d 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -46,15 +46,9 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 63e78ca20bcee..573e35de7b42e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -48,15 +48,9 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 81527a8111bb0..1b0f492372eef 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -144,13 +144,7 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index
-// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index
-// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// LOAD-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -184,9 +178,8 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source2(
-// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[DIM]], 8, 16], strides : [%c128, %c16, %c1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND-DAG: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[SIZES]]#0, 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 83d33e1905f7c..8ca86c39d640d 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -84,15 +84,9 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
-// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index
-// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index
-// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND: {{.*}} %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
+// STORE-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, %c1]
+// STORE-ND-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
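
The constified sizes and strides keep statically known extents as attributes and fall back to the metadata op's SSA results only for genuinely dynamic ones, which is why the CHECK line for `@load_dynamic_source2` now expects `shape : [%[[SIZES]]#0, 8, 16]` rather than three SSA values. A rough standalone sketch of that per-dimension selection (string placeholders stand in for SSA values; `kDynamic` mirrors `ShapedType::kDynamic`):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Sentinel standing in for ShapedType::kDynamic in this sketch.
constexpr int64_t kDynamic = INT64_MIN;

// Per dimension, prefer the static extent from the memref type and use the
// runtime value (an SSA result of memref.extract_strided_metadata) only when
// the extent is dynamic.
static std::vector<std::string>
constifySizes(const std::vector<int64_t> &staticShape,
              const std::vector<std::string> &runtimeSizes) {
  std::vector<std::string> mixed;
  for (size_t i = 0; i < staticShape.size(); ++i)
    mixed.push_back(staticShape[i] == kDynamic
                        ? runtimeSizes[i]
                        : std::to_string(staticShape[i]));
  return mixed;
}

int main() {
  // memref<?x8x16xf32>: only the outermost size remains a runtime value.
  auto mixed = constifySizes({kDynamic, 8, 16},
                             {"%sizes#0", "%sizes#1", "%sizes#2"});
  assert((mixed == std::vector<std::string>{"%sizes#0", "8", "16"}));
  return 0;
}
```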
>From babf57e4fded2d1973885aa431bc26d77b175f15 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 30 Oct 2025 11:16:25 +0000
Subject: [PATCH 12/13] revert xegpu def changes
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 45 -------------
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 67 ++++++++++---------
mlir/test/Dialect/XeGPU/invalid.mlir | 12 +++-
3 files changed, 44 insertions(+), 80 deletions(-)
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 489bd513a0bd4..426377fcf598f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -261,21 +261,6 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
: !xegpu.tensor_desc<8x16xf16>
```
- The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater than or equal to the rank of the tensor
- descriptor and less than or equal to the rank of the source memref.
- The offsets are applied to the innermost dimensions of the source memref.
-
- Examples:
- ```mlir
- %tdesc = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
- // memref[0, 0, %off0, %off1]
- xegpu.prefetch_nd %tdesc[%off0, %off1] : !xegpu.tensor_desc<8x16xf16>
- // memref[0, %off0, %off1, %off2]
- xegpu.prefetch_nd %tdesc[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf16>
- // memref[%off0, %off1, %off2, %off3]
- xegpu.prefetch_nd %tdesc[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf16>
- ```
}];
let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -365,21 +350,6 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
: !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
```
- The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater than or equal to the rank of the tensor
- descriptor and less than or equal to the rank of the source memref.
- The offsets are applied to the innermost dimensions of the source memref.
-
- Examples:
- ```mlir
- %1 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
- // memref[0, 0, %off0, %off1]
- xegpu.load_nd %1[%off0, %off1] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- // memref[0, %off0, %off1, %off2]
- xegpu.load_nd %1[%off0, %off1, %off2] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- // memref[%off0, %off1, %off2, %off3]
- xegpu.load_nd %1[%off0, %off1, %off2, %off3] : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
- ```
}];
@@ -475,21 +445,6 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
: vector<8xf16>, !xegpu.tensor_desc<8x16xf16>
```
- The operation may take optional offsets for the tensor descriptor.
- The number of offsets must be greater than or equal to the rank of the tensor
- descriptor and less than or equal to the rank of the source memref.
- The offsets are applied to the innermost dimensions of the source memref.
-
- Examples:
- ```mlir
- %2 = xegpu.create_nd_tdesc %0: memref<2x8x32x32xf32> -> TensorDesc<8x16xf32>
- // memref[0, 0, %off0, %off1]
- xegpu.store_nd %3, %2[%off0, %off1] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
- // memref[0, %off0, %off1, %off2]
- xegpu.store_nd %3, %2[%off0, %off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
- // memref[%off0, %off1, %off2, %off3]
- xegpu.store_nd %3, %2[%off0, %off1, %off2, %off3] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
- ```
}];
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 76640bb59be46..abd12e2e69ac0 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,22 +121,6 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
-// Verify that number of offsets matches either the source rank or the tdesc
-// rank.
-static LogicalResult
-isValidNdOffset(TypedValue<TensorDescType> tDesc,
- std::optional<llvm::ArrayRef<int64_t>> constOffsets,
- int64_t offsetSize,
- function_ref<InFlightDiagnostic()> emitError) {
- int64_t constOffsetSize = constOffsets ? constOffsets->size() : 0;
- auto tDescRank = tDesc.getType().getRank();
- if (((offsetSize != 0) && (offsetSize < tDescRank)) ||
- ((constOffsetSize != 0) && (constOffsetSize < tDescRank)))
- return emitError() << "Offsets rank cannot be smaller than tensor "
- "descriptor rank.";
- return success();
-}
-
static LogicalResult
isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
VectorType valueTy, int64_t chunkSize,
@@ -274,10 +258,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean (only do so for full-static case, otherwise
- // printer would fail trying to print empty array-attr).
- if (staticShape == memrefShape && staticStrides == memrefStrides &&
- dynamicShape.empty() && dynamicStrides.empty()) {
+ // to keep the IR print clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -338,10 +320,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean (only do so for full-static case, otherwise
- // printer would fail trying to print empty array-attr).
- if (staticShape == memrefShape && staticStrides == memrefStrides &&
- dynamicShape.empty() && dynamicStrides.empty()) {
+ // to keep the IR print clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -491,9 +471,16 @@ LogicalResult PrefetchNdOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- auto tDesc = getTensorDesc();
- return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
- [&]() { return emitOpError(); });
+ int64_t tDescRank = tdescTy.getRank();
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ return emitOpError(
+ "Mismatched ranks between offsets and tensor descriptor");
+
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -609,9 +596,16 @@ LogicalResult LoadNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< tdescTy;
- auto tDesc = getTensorDesc();
- return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
- [&]() { return emitOpError(); });
+ int64_t tDescRank = tdescTy.getRank();
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ return emitOpError(
+ "Mismatched ranks between offsets and tensor descriptor");
+
+ return success();
}
//===----------------------------------------------------------------------===//
@@ -696,9 +690,16 @@ LogicalResult StoreNdOp::verify() {
<< " is not consistent with tensor descriptor "
<< dstTy;
- auto tDesc = getTensorDesc();
- return isValidNdOffset(tDesc, getConstOffsets(), getMixedOffsets().size(),
- [&]() { return emitOpError(); });
+ int64_t tDescRank = dstTy.getRank();
+ int64_t offsetSize = static_cast<int64_t>(getOffsets().size());
+ int64_t constOffsetSize =
+ getConstOffsetsAttr() ? getConstOffsetsAttr().size() : 0;
+ if (((offsetSize != 0) && (offsetSize != tDescRank)) ||
+ ((constOffsetSize != 0) && (constOffsetSize != tDescRank)))
+ return emitOpError(
+ "Mismatched ranks between offsets and tensor descriptor");
+
+ return success();
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 4b710d3f51557..ebbe3ce0ec0d0 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -132,10 +132,18 @@ func.func @subgroup_load_nd_9(%src: memref<4x8x16xf16>) {
return
}
+// -----
+func.func @subgroup_load_nd_offset_1(%src: memref<4x8x16xf16>, %x : index) {
+ %1 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<16xf16>
+// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
+ %2 = xegpu.load_nd %1[0, 0] : !xegpu.tensor_desc<16xf16> -> vector<16xf16>
+ return
+}
+
// -----
func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
- // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
+ // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
xegpu.prefetch_nd %3[0] : !xegpu.tensor_desc<8x16xf16>
return
}
@@ -144,7 +152,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
- // expected-error@+1 {{Offsets rank cannot be smaller than tensor descriptor rank.}}
+ // expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
xegpu.store_nd %5, %3[%x] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
return
}
>From 2a38c2cc7733c3ba934a66e5a198db864edd9c49 Mon Sep 17 00:00:00 2001
From: dchigarev <dmitry.chigarev at intel.com>
Date: Thu, 30 Oct 2025 14:32:34 +0000
Subject: [PATCH 13/13] collapse memref shape to 2d
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
---
.../VectorToXeGPU/VectorToXeGPU.cpp | 128 ++++++++++++++----
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 12 +-
.../VectorToXeGPU/load-to-xegpu.mlir | 29 ++--
.../VectorToXeGPU/store-to-xegpu.mlir | 29 ++--
.../VectorToXeGPU/transfer-read-to-xegpu.mlir | 38 +++---
.../transfer-write-to-xegpu.mlir | 41 +++---
6 files changed, 179 insertions(+), 98 deletions(-)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index ee2e8a69edcc0..34c302b4968c5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -358,6 +358,63 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
.getResult();
}
+// Collapses the shape of an nD memref to the target rank while applying
+// offsets for the collapsed dimensions. Returns the new memref value and the
+// remaining offsets for the last targetRank dimensions. For example, with
+// input %memref = memref<2x4x8x32xf32>, offsets = [%i0, %i1, %i2, %i3], and
+// targetRank = 2, the result is %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>
+// with remaining offsets [%i2, %i3].
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+ Value memref,
+ SmallVector<OpFoldResult> offsets,
+ int64_t targetRank) {
+ auto memrefType = cast<MemRefType>(memref.getType());
+ unsigned rank = memrefType.getRank();
+
+ if (rank <= targetRank)
+ return {memref, offsets};
+
+ int64_t numCombinedDims = rank - targetRank;
+ SmallVector<OpFoldResult> subviewOffsets;
+ SmallVector<OpFoldResult> subviewSizes;
+ SmallVector<OpFoldResult> subviewStrides;
+
+ // For the combined dimensions: use the provided offsets, size=1, stride=1
+ for (unsigned i = 0; i < numCombinedDims; ++i) {
+ subviewOffsets.push_back(offsets[i]);
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ // For the last targetRank dimensions: offset=0, use full size, stride=1
+ SmallVector<int64_t> resultShape;
+ auto originalShape = memrefType.getShape();
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+ for (unsigned i = numCombinedDims; i < rank; ++i) {
+ subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ if (ShapedType::isDynamic(originalShape[i])) {
+ subviewSizes.push_back(meta.getSizes()[i]);
+ resultShape.push_back(ShapedType::kDynamic);
+ } else {
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+ resultShape.push_back(originalShape[i]);
+ }
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ auto resultType = memref::SubViewOp::inferRankReducedResultType(
+ resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+ auto subviewOp =
+ memref::SubViewOp::create(rewriter, loc, resultType, memref,
+ subviewOffsets, subviewSizes, subviewStrides);
+
+ // Return the remaining offsets for the last targetRank dimensions
+ SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+ offsets.end());
+ return {subviewOp.getResult(), newOffsets};
+}
+
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -493,17 +550,18 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+ vecTy.getRank());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()));
-
- auto loadOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(readOp.getIndices()),
- /*packed=*/nullptr, transposeAttr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+ /*packed=*/nullptr, transposeAttr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(readOp, loadOp);
return success();
@@ -541,21 +599,23 @@ struct TransferWriteLowering
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, writeOp.getBase(),
+ getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()));
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
- auto storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- getAsOpFoldResult(writeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+ ndDesc, indices,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -643,17 +703,21 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+ vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType, loadOp.getBase());
- auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, getAsOpFoldResult(loadOp.getIndices()),
- /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+ auto loadNdOp =
+ xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -675,19 +739,23 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, storeOp.getBase(),
+ getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType, storeOp.getBase());
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
- auto storeNdOp = xegpu::StoreNdOp::create(
- rewriter, loc, vector, ndDesc, getAsOpFoldResult(storeOp.getIndices()),
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ auto storeNdOp =
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(storeOp, storeNdOp);
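
Stripped of the rewriter plumbing, the collapse performed by `convertMemrefAndOffsetsToTargetRank` folds the leading dimensions into a rank-reduced subview using the caller's offsets with unit sizes, keeps offset zero and the full extent for the trailing dimensions, and forwards only the trailing offsets to the nd op. A standalone sketch of that bookkeeping (static sizes only for brevity; the real pattern also handles dynamic extents via `memref.extract_strided_metadata`):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct CollapsePlan {
  std::vector<int64_t> subviewOffsets;   // offsets fed to memref.subview
  std::vector<int64_t> subviewSizes;     // sizes fed to memref.subview
  std::vector<int64_t> remainingOffsets; // offsets kept for load_nd/store_nd
};

// Splits an nD access into a rank-reduced subview plus the offsets that
// remain for the last targetRank dimensions.
static CollapsePlan planCollapse(const std::vector<int64_t> &shape,
                                 const std::vector<int64_t> &offsets,
                                 size_t targetRank) {
  assert(shape.size() == offsets.size() && targetRank <= shape.size());
  size_t numCombined = shape.size() - targetRank;
  CollapsePlan plan;
  for (size_t i = 0; i < numCombined; ++i) {
    plan.subviewOffsets.push_back(offsets[i]); // collapsed dims take the offset
    plan.subviewSizes.push_back(1);
  }
  for (size_t i = numCombined; i < shape.size(); ++i) {
    plan.subviewOffsets.push_back(0);      // trailing dims start at 0
    plan.subviewSizes.push_back(shape[i]); // and keep their full extent
    plan.remainingOffsets.push_back(offsets[i]);
  }
  return plan;
}

int main() {
  // memref<2x4x8x32xf32>, offsets [i0, i1, i2, i3], targetRank = 2:
  // subview %src[i0, i1, 0, 0] [1, 1, 8, 32] -> memref<8x32xf32>, offsets [i2, i3].
  CollapsePlan plan = planCollapse({2, 4, 8, 32}, {0, 1, 2, 3}, 2);
  assert((plan.subviewOffsets == std::vector<int64_t>{0, 1, 0, 0}));
  assert((plan.subviewSizes == std::vector<int64_t>{1, 1, 8, 32}));
  assert((plan.remainingOffsets == std::vector<int64_t>{2, 3}));
  return 0;
}
```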
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index abd12e2e69ac0..8ed8b26dd2a0e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -258,8 +258,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -320,8 +322,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index a3ed559f6413d..ae5141db16c09 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
// -----
@@ -28,29 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
- %offset: index) -> vector<8x16xf32> {
- %0 = vector.load %source[%offset, %offset, %offset]
+ %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 573e35de7b42e..1a10d917623cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// -----
@@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
func.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.store %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.store %vec, %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return
}
@@ -47,12 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-LABEL: @store_dynamic_source(
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// CHECK-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 1b0f492372eef..c87a5304babfe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -12,11 +12,12 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-ND-LABEL: @load_1D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_1D_vector(
@@ -46,11 +47,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-LABEL: @load_2D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_2D_vector(
@@ -143,10 +145,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata %[[SRC]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -178,9 +181,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source2(
-// LOAD-ND-DAG: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}, shape : [%[[SIZES]]#0, 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
// LOAD-GATHER-LABEL: @load_dynamic_source2(
@@ -411,11 +416,12 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8xf16>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 8ca86c39d640d..43a1a7206e2cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
@@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// STORE-SCATTER-LABEL: @store_1D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>,
@@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_2D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// -----
gpu.module @xevm_module {
gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.transfer_write %vec, %source[%i, %j, %k]
{in_bounds = [true, true]}
: vector<8x16xf32>, memref<?x?x?xf32>
gpu.return
@@ -83,12 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-LABEL: @store_dynamic_source(
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// STORE-ND-SAME: %[[OFFSET:.+]]: index
-// STORE-ND: {{.*}} %[[SIZES:.+]]:3, %[[STRIDES:.+]]:3 = memref.extract_strided_metadata
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
-// STORE-ND-SAME: , shape : [%[[SIZES]]#0, %[[SIZES]]#1, %[[SIZES]]#2], strides : [%[[STRIDES]]#0, %[[STRIDES]]#1, 1]
-// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
+// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -292,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
-// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
-// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SUBVIEW]]
-// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]], %[[OFF2]]] : vector<8xf16>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
// STORE-SCATTER-LABEL: @store_to_subview(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>,