[Mlir-commits] [mlir] 8b0a24a - [MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. (#171216)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Dec 18 11:03:56 PST 2025
Author: Sang Ik Lee
Date: 2025-12-18T11:03:51-08:00
New Revision: 8b0a24a50da54587b183ff7ff8f7144f6a0aed85
URL: https://github.com/llvm/llvm-project/commit/8b0a24a50da54587b183ff7ff8f7144f6a0aed85
DIFF: https://github.com/llvm/llvm-project/commit/8b0a24a50da54587b183ff7ff8f7144f6a0aed85.diff
LOG: [MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. (#171216)
If the source strided memref is not fully static - that is, at least one of
its shape, strides, or offset is kDynamic - use the i64 source variant.
With this change, an xegpu.create_nd_tdesc created by lowering from the vector
dialect can rely on getMixedOffsets, getMixedSizes and getMixedStrides
to get the relevant values.
Added:
Modified:
mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 079e1e2a8ac67..55ade0ae8eeec 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,46 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
xegpu::TensorDescType descType,
TypedValue<MemRefType> src) {
MemRefType srcTy = src.getType();
+ assert(srcTy.isStrided() && "Expected strided memref type");
auto [strides, offset] = srcTy.getStridesAndOffset();
+ bool isStatic = true;
+
+ // Memref is dynamic if any of its shape, offset or strides is dynamic.
+ if (!srcTy.hasStaticShape())
+ isStatic = false;
+
+ if (!ShapedType::isStatic(offset))
+ isStatic = false;
+
+ for (auto stride : strides) {
+ if (!ShapedType::isStatic(stride)) {
+ isStatic = false;
+ break;
+ }
+ }
xegpu::CreateNdDescOp ndDesc;
- if (srcTy.hasStaticShape()) {
+ if (isStatic) {
ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
} else {
- // In case of any dynamic shapes, source's shape and strides have to be
+ // In case of ranked dynamic memref, instead of passing on the memref,
+ // i64 base address, source's offset, shape and strides have to be
// explicitly provided.
auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
- ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
- meta.getConstifiedMixedSizes(),
- meta.getConstifiedMixedStrides());
+ auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+ rewriter, loc, meta.getBaseBuffer());
+ auto offset = meta.getOffset();
+ auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
+ auto offsetInBytes = arith::MulIOp::create(
+ rewriter, loc, offset,
+ arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
+ auto adjustedBaseAddr = arith::AddIOp::create(
+ rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
+ auto adjustedAddrI64 = arith::IndexCastOp::create(
+ rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
+ ndDesc = xegpu::CreateNdDescOp::create(
+ rewriter, loc, descType, adjustedAddrI64,
+ meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
}
return ndDesc;
diff --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index ae5141db16c09..c77efa03f3483 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,10 +9,17 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
// CHECK-LABEL: @load_1D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
// CHECK: return %[[VEC]]
@@ -29,10 +36,16 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
// CHECK-LABEL: @load_2D_vector(
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
@@ -48,9 +61,15 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-LABEL: @load_dynamic_source(
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 1a10d917623cc..3c11313d05536 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,10 +11,17 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// CHECK-SAME: boundary_check = false
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
@@ -31,10 +38,16 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// CHECK-SAME: %[[OFFSET:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME: %[[COLLAPSED]]
-// CHECK-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME: : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
// -----
@@ -50,9 +63,15 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// CHECK-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// CHECK-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK: %[[ELEM_BYTES:.*]] = arith.constant 4 : index
// CHECK: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// -----
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 8bb272b1fe5fc..b58f9b30ed726 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -48,10 +48,16 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
// LOAD-ND-LABEL: @load_2D_vector(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// LOAD-ND-SAME: %[[OFFSET:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[COLLAPSED]]
-// LOAD-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME: : memref<f32> -> index
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// LOAD-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// LOAD-ND-SAME: boundary_check = false
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -147,9 +153,16 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// LOAD-ND-LABEL: @load_dynamic_source(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME: #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
// LOAD-ND: return %[[VEC]]
@@ -184,8 +197,15 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
// LOAD-ND-LABEL: @load_dynamic_source2(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<?x8x16xf32>,
// LOAD-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// LOAD-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
// LOAD-ND: return %[[VEC]] : vector<8x16xf32>
@@ -459,11 +479,15 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
// LOAD-ND-LABEL: @load_from_subview_2D(
// LOAD-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// LOAD-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND: %[[ELEM_BYTES:.+]] = arith.constant 2 : index
// LOAD-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME: %[[SUBVIEW]]
-// LOAD-ND-SAME: memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME: boundary_check = false
+// LOAD-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND: %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND: %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME: i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// LOAD-ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
// LOAD-ND: return %[[VEC]]
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..66da64225678e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -15,10 +15,17 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND: %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME: : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// STORE-ND-SAME: strides : [1] : i64 -> !xegpu.tensor_desc<8xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
@@ -50,10 +57,16 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<8x16x32xf32>,
// STORE-ND-SAME: %[[OFFSET:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME: : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// STORE-ND-SAME: strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
// STORE-ND-SAME: boundary_check = false
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
@@ -86,9 +99,15 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8x16xf32>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<?x?x?xf32>,
// STORE-ND-SAME: %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 4 : index
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND: {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME: strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
// STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -293,12 +312,16 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
// STORE-ND-SAME: %[[VEC:.+]]: vector<8xf16>,
// STORE-ND-SAME: %[[SRC:.+]]: memref<4096x4096xf16>,
// STORE-ND-SAME: %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND: %[[ELEM_BYTES:.+]] = arith.constant 2 : index
// STORE-ND: %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
// STORE-ND: %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME: %[[COLLAPSED]]
-// STORE-ND-SAME: memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME: boundary_check = false
+// STORE-ND: %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND: %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND: %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND: %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND: %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME: !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
// STORE-ND: xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
// STORE-SCATTER-LABEL: @store_to_subview(
More information about the Mlir-commits
mailing list