[Mlir-commits] [mlir] 747050b - [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (#162095)
llvmlistbot at llvm.org
Tue Nov 4 04:52:28 PST 2025
Author: Dmitry Chigarev
Date: 2025-11-04T13:52:23+01:00
New Revision: 747050bcceca18d32dc1140461984ec2c30ae96a
URL: https://github.com/llvm/llvm-project/commit/747050bcceca18d32dc1140461984ec2c30ae96a
DIFF: https://github.com/llvm/llvm-project/commit/747050bcceca18d32dc1140461984ec2c30ae96a.diff
LOG: [MLIR][XeGPU][VectorToXeGPU] Lower vector.load/store/transfer_read/transfer_write to new offsets syntax (#162095)
Changes the `VectorToXeGPU` pass to generate `xegpu.load_nd/store_nd`
ops using the new syntax where offsets are specified at the load/store
op level.
```mlir
// from this
%desc = xegpu.create_nd_tdesc %src[%off1, %off2]: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// to this
%desc = xegpu.create_nd_tdesc %src: memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%res = xegpu.load_nd %desc[%off1, %off2] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```
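Stores follow the same pattern: offsets move from `create_nd_tdesc` to `store_nd`. A minimal sketch mirroring the updated tests, where `%data` and `%dst` are illustrative values:
```mlir
// offsets are now passed to xegpu.store_nd rather than baked into the descriptor
%desc = xegpu.create_nd_tdesc %dst : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
xegpu.store_nd %data, %desc[%off1, %off2] : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
```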
To support cases with dimension reduction at the
`create_nd_tdesc` level (e.g. `memref<8x8x16xf16> ->
tensor_desc<8x16xf16>`), a `memref.subview` is inserted that
collapses the source shape to 2D, for example:
```mlir
// input:
%0 = vector.load %source[%off0, %off1, %off2] : memref<8x16x32xf32>, vector<8x16xf32>
// --vector-to-xegpu (old)
%tdesc = xegpu.create_nd_tdesc %source[%off0, %off1, %off2] : memref<8x16x32xf32> -> tdesc<8x16xf32>
%vec = xegpu.load_nd %tdesc
// --vector-to-xegpu (new)
%collapsed = memref.subview %source[%off0, 0, 0] [1, 16, 32] [1, 1, 1] :
    memref<8x16x32xf32> to memref<16x32xf32, strided<[32, 1], offset: ?>>
%tdesc = xegpu.create_nd_tdesc %collapsed : memref<16x32xf32, ...> -> tdesc<8x16xf32>
%vec = xegpu.load_nd %tdesc[%off1, %off2]
```
<details><summary>Why do we need this change?</summary>
```mlir
// reduce dim and apply all 3 offsets at load_nd
%desc = xegpu.create_nd_tdesc %source : memref<8x16x32xf32> -> !xegpu.tensor_desc<16x32xf32>
// error: xegpu.load_nd len(offsets) != desc.rank
%res = xegpu.load_nd %desc[%off, %off, %off] : !xegpu.tensor_desc<16x32xf32> -> vector<8x16xf32>
```
</details>
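For sources with dynamic shapes, the collapsed source's sizes and strides are now recovered via `memref.extract_strided_metadata` instead of building `memref.dim`/`arith.muli` chains. A rough sketch of the resulting IR, with value names and elided types (`...`) being illustrative only:
```mlir
// input
%0 = vector.load %src[%i, %j, %k] : memref<?x?x?xf32>, vector<8x16xf32>
// --vector-to-xegpu (new)
%collapsed = memref.subview %src[%i, 0, 0] ...   // rank-reducing subview to 2D
%base, %off, %sizes:2, %strides:2 = memref.extract_strided_metadata %collapsed
%tdesc = xegpu.create_nd_tdesc %collapsed ...    // shape/strides taken from the metadata
%vec = xegpu.load_nd %tdesc[%j, %k] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```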
---------
Signed-off-by: dchigarev <dmitry.chigarev at intel.com>
Added:
Modified:
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 91c1aa55fdb4e..abea84f6b01fe 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -97,57 +97,23 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return success();
}
-static xegpu::CreateNdDescOp
-createNdDescriptor(PatternRewriter &rewriter, Location loc,
- xegpu::TensorDescType descType, TypedValue<MemRefType> src,
- Operation::operand_range offsets) {
+static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
+ Location loc,
+ xegpu::TensorDescType descType,
+ TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
auto [strides, offset] = srcTy.getStridesAndOffset();
xegpu::CreateNdDescOp ndDesc;
if (srcTy.hasStaticShape()) {
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- getAsOpFoldResult(offsets));
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
// In case of any dynamic shapes, source's shape and strides have to be
// explicitly provided.
- SmallVector<Value> sourceDims;
- unsigned srcRank = srcTy.getRank();
- for (unsigned i = 0; i < srcRank; ++i)
- sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));
-
- SmallVector<int64_t> constOffsets;
- SmallVector<Value> dynOffsets;
- for (Value offset : offsets) {
- std::optional<int64_t> staticVal = getConstantIntValue(offset);
- if (!staticVal)
- dynOffsets.push_back(offset);
- constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
- }
-
- SmallVector<Value> dynShapes;
- for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
- if (shape == ShapedType::kDynamic)
- dynShapes.push_back(sourceDims[idx]);
- }
-
- // Compute strides in reverse order.
- SmallVector<Value> dynStrides;
- Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
- // Last stride is guaranteed to be static and unit.
- for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
- accStride =
- arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
- if (strides[i] == ShapedType::kDynamic)
- dynStrides.push_back(accStride);
- }
- std::reverse(dynStrides.begin(), dynStrides.end());
-
- ndDesc = xegpu::CreateNdDescOp::create(
- rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
- DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
- DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
- DenseI64ArrayAttr::get(rewriter.getContext(), strides));
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
+ ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
+ meta.getConstifiedMixedSizes(),
+ meta.getConstifiedMixedStrides());
}
return ndDesc;
@@ -392,6 +358,62 @@ static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
.getResult();
}
+// Collapses shapes of a nD memref to the target rank while applying offsets for
+// the collapsed dimensions. Returns the new memref value and the remaining
+// offsets for the last targetRank dimensions. For example:
+// input: %memref = memref<2x4x8x32xf32>, offsets=[%i0, %i1, %i2, %i3],
+// output: %memref[%i0, %i1, 0, 0] -> memref<8x32xf32>, offsets: [%i2, %i3]
+static std::pair<Value, SmallVector<OpFoldResult>>
+convertMemrefAndOffsetsToTargetRank(PatternRewriter &rewriter, Location loc,
+ Value memref,
+ SmallVector<OpFoldResult> offsets,
+ int64_t targetRank) {
+ auto memrefType = cast<MemRefType>(memref.getType());
+ unsigned rank = memrefType.getRank();
+
+ if (rank <= targetRank)
+ return {memref, offsets};
+
+ int64_t numCombinedDims = rank - targetRank;
+ SmallVector<OpFoldResult> subviewOffsets;
+ SmallVector<OpFoldResult> subviewSizes;
+ SmallVector<OpFoldResult> subviewStrides;
+
+ // For the combined dimensions: use the provided offsets, size=1, stride=1
+ for (unsigned i = 0; i < numCombinedDims; ++i) {
+ subviewOffsets.push_back(offsets[i]);
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(1));
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ // For the last targetRank dimensions: offset=0, use full size, stride=1
+ SmallVector<int64_t> resultShape;
+ auto originalShape = memrefType.getShape();
+ auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, memref);
+ for (unsigned i = numCombinedDims; i < rank; ++i) {
+ subviewOffsets.push_back(rewriter.getI64IntegerAttr(0));
+ if (ShapedType::isDynamic(originalShape[i])) {
+ subviewSizes.push_back(meta.getSizes()[i]);
+ resultShape.push_back(ShapedType::kDynamic);
+ } else {
+ subviewSizes.push_back(rewriter.getI64IntegerAttr(originalShape[i]));
+ resultShape.push_back(originalShape[i]);
+ }
+ subviewStrides.push_back(rewriter.getI64IntegerAttr(1));
+ }
+
+ auto resultType = memref::SubViewOp::inferRankReducedResultType(
+ resultShape, memrefType, subviewOffsets, subviewSizes, subviewStrides);
+ auto subviewOp =
+ memref::SubViewOp::create(rewriter, loc, resultType, memref,
+ subviewOffsets, subviewSizes, subviewStrides);
+
+ // Return the remaining offsets for the last targetRank dimensions
+ SmallVector<OpFoldResult> newOffsets(offsets.begin() + numCombinedDims,
+ offsets.end());
+ return {subviewOp.getResult(), newOffsets};
+}
+
template <
typename OpType,
typename = std::enable_if_t<llvm::is_one_of<
@@ -523,18 +545,19 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
descShape, elementType, /*array_length=*/1,
/*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
- readOp.getIndices());
-
DenseI64ArrayAttr transposeAttr =
!isTransposeLoad ? nullptr
: DenseI64ArrayAttr::get(rewriter.getContext(),
ArrayRef<int64_t>{1, 0});
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, readOp.getBase(), getAsOpFoldResult(readOp.getIndices()),
+ vecTy.getRank());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
/*packed=*/nullptr, transposeAttr,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
@@ -575,21 +598,23 @@ struct TransferWriteLowering
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, writeOp.getBase(),
+ getAsOpFoldResult(writeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc =
- createNdDescriptor(rewriter, loc, descType,
- dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
- writeOp.getIndices());
-
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
- auto storeOp =
- xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
+ auto storeOp = xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(),
+ ndDesc, indices,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(writeOp, storeOp);
return success();
@@ -674,19 +699,24 @@ struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ // By default, no specific caching policy is assigned.
+ xegpu::CachePolicyAttr hint = nullptr;
+
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, loadOp.getBase(), getAsOpFoldResult(loadOp.getIndices()),
+ vecTy.getRank());
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());
- // By default, no specific caching policy is assigned.
- xegpu::CachePolicyAttr hint = nullptr;
- auto loadNdOp = xegpu::LoadNdOp::create(
- rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
- /*l1_hint=*/hint,
- /*l2_hint=*/hint, /*l3_hint=*/hint);
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+ auto loadNdOp =
+ xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc, indices,
+ /*packed=*/nullptr, /*transpose=*/nullptr,
+ /*l1_hint=*/hint,
+ /*l2_hint=*/hint, /*l3_hint=*/hint);
rewriter.replaceOp(loadOp, loadNdOp);
return success();
@@ -708,18 +738,24 @@ struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
// Boundary check is available only for block instructions.
bool boundaryCheck = vecTy.getRank() > 1;
+ auto [src, indices] = convertMemrefAndOffsetsToTargetRank(
+ rewriter, loc, storeOp.getBase(),
+ getAsOpFoldResult(storeOp.getIndices()), vecTy.getRank());
+
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
- xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
- rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());
// By default, no specific caching policy is assigned.
xegpu::CachePolicyAttr hint = nullptr;
+ xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
+ rewriter, loc, descType, dyn_cast<TypedValue<MemRefType>>(src));
+
auto storeNdOp =
- xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
+ xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc, indices,
/*l1_hint=*/hint,
/*l2_hint=*/hint, /*l3_hint=*/hint);
+
rewriter.replaceOp(storeOp, storeNdOp);
return success();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7b6c4b6c2c813..c8f5c86c03686 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -280,8 +280,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
@@ -342,8 +344,10 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
// if shape and strides are from Memref, we don't need attributes for them
- // to keep the IR print clean.
- if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ // to keep the IR print clean (only do so for full-static case, otherwise
+ // printer would fail trying to print empty array-attr).
+ if (staticShape == memrefShape && staticStrides == memrefStrides &&
+ dynamicShape.empty() && dynamicStrides.empty()) {
staticShapeAttr = DenseI64ArrayAttr();
staticStridesAttr = DenseI64ArrayAttr();
}
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index 9908205f07c92..ae5141db16c09 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,11 +9,12 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
// -----
@@ -28,35 +29,29 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
- %offset: index) -> vector<8x16xf32> {
- %0 = vector.load %source[%offset, %offset, %offset]
+ %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+ %0 = vector.load %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return %0 : vector<8x16xf32>
}
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
@@ -72,9 +67,9 @@ func.func @load_out_of_bounds(%source: memref<7x15xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x15xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x15xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 2c498dcc2a071..1a10d917623cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,11 +11,12 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// -----
@@ -30,16 +31,17 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME: %[[COLLAPSED]]
+// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
func.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.store %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.store %vec, %source[%i, %j, %k]
: memref<?x?x?xf32>, vector<8x16xf32>
return
}
@@ -47,18 +49,11 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-LABEL: @store_dynamic_source(
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME: %[[OFFSET:.+]]: index
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// -----
@@ -74,9 +69,9 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME: %[[SRC]]
// CHECK-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index c4ca79af1bd9a..c87a5304babfe 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -12,11 +12,12 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
// LOAD-ND-LABEL: @load_1D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_1D_vector(
@@ -46,11 +47,12 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-LABEL: @load_2D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_2D_vector(
@@ -83,9 +85,9 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
// LOAD-ND-LABEL: @load_zero_pad_out_of_bounds(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_zero_pad_out_of_bounds(
@@ -109,9 +111,9 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
// LOAD-ND-SAME: %[[SRC:.+]]: memref<32x64xf32>,
// LOAD-ND-SAME: %[[OFFSET1:.+]]: index,
// LOAD-ND-SAME: %[[OFFSET2:.+]]: index
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]]
// LOAD-ND-SAME: memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET1]], %[[OFFSET2]]] <{transpose = array<i64: 1, 0>}>
// LOAD-ND-SAME: -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -143,16 +145,11 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD-ND-SAME: %[[OFFSET:.+]]: index
-// LOAD-ND: %[[C2:.+]] = arith.constant 2 : index
-// LOAD-ND: %[[C1:.+]] = arith.constant 1 : index
-// LOAD-ND: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -184,10 +181,11 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
}
// LOAD-ND-LABEL: @load_dynamic_source2(
-// LOAD-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// LOAD-ND-DAG: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
+// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
// LOAD-GATHER-LABEL: @load_dynamic_source2(
@@ -418,11 +416,12 @@ gpu.func @load_from_subview(%source: memref<4096x4096xf16>, %off1: index, %off2:
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// LOAD-ND-SAME: %[[COLLAPSED]]
+// LOAD-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// LOAD-ND-SAME: boundary_check = false
-// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf16>
+// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]]]{{.*}}-> vector<8xf16>
// LOAD-ND: return %[[VEC]]
// LOAD-GATHER-LABEL: @load_from_subview(
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index fcfc9414da4f6..43a1a7206e2cc 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
@@ -15,11 +15,12 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
// STORE-SCATTER-LABEL: @store_1D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf32>,
@@ -49,11 +50,12 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_2D_vector(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -73,8 +75,8 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// -----
gpu.module @xevm_module {
gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
- %source: memref<?x?x?xf32>, %offset: index) {
- vector.transfer_write %vec, %source[%offset, %offset, %offset]
+ %source: memref<?x?x?xf32>, %i: index, %j: index, %k: index) {
+ vector.transfer_write %vec, %source[%i, %j, %k]
{in_bounds = [true, true]}
: vector<8x16xf32>, memref<?x?x?xf32>
gpu.return
@@ -83,18 +85,11 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-LABEL: @store_dynamic_source(
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
-// STORE-ND-SAME: %[[OFFSET:.+]]: index
-// STORE-ND-DAG: %[[C0:.+]] = arith.constant 0 : index
-// STORE-ND-DAG: %[[C1:.+]] = arith.constant 1 : index
-// STORE-ND-DAG: %[[C2:.+]] = arith.constant 2 : index
-// STORE-ND-DAG: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// STORE-ND-DAG: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// STORE-ND-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// STORE-ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// STORE-ND-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// STORE-ND-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
+// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8x16xf32>,
@@ -126,9 +121,9 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<7x64xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME: %[[SRC]]
// STORE-ND-SAME: memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_out_of_bounds(
// STORE-SCATTER: vector.transfer_write
@@ -298,13 +293,13 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
-// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1]
-// STORE-ND-SAME: : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
+// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[SUBVIEW]][%[[OFF2]], %[[OFF2]]]
-// STORE-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
+// STORE-ND-SAME: %[[COLLAPSED]]
+// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
// STORE-ND-SAME: boundary_check = false
-// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf16>
+// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
// STORE-SCATTER-LABEL: @store_to_subview(
// STORE-SCATTER-SAME: %[[VEC:.+]]: vector<8xf16>,