[Mlir-commits] [mlir] 8b0a24a - [MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. (#171216)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Dec 18 11:03:56 PST 2025


Author: Sang Ik Lee
Date: 2025-12-18T11:03:51-08:00
New Revision: 8b0a24a50da54587b183ff7ff8f7144f6a0aed85

URL: https://github.com/llvm/llvm-project/commit/8b0a24a50da54587b183ff7ff8f7144f6a0aed85
DIFF: https://github.com/llvm/llvm-project/commit/8b0a24a50da54587b183ff7ff8f7144f6a0aed85.diff

LOG: [MLIR] Vector to XeGPU conversion: Use proper source variant for create_nd_tdesc op creation. (#171216)

If the source strided memref is not fully static (at least one of its shape,
strides, or offset is kDynamic), use the i64 source variant.
With this change, an xegpu.create_nd_tdesc created by lowering from the vector
dialect can rely on getMixedOffsets, getMixedSizes and getMixedStrides
to get the relevant values.

Added: 
    

Modified: 
    mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
    mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
    mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 079e1e2a8ac67..55ade0ae8eeec 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -102,18 +102,46 @@ static xegpu::CreateNdDescOp createNdDescriptor(PatternRewriter &rewriter,
                                                 xegpu::TensorDescType descType,
                                                 TypedValue<MemRefType> src) {
   MemRefType srcTy = src.getType();
+  assert(srcTy.isStrided() && "Expected strided memref type");
   auto [strides, offset] = srcTy.getStridesAndOffset();
+  bool isStatic = true;
+
+  // Memref is dynamic if any of its shape, offset or strides is dynamic.
+  if (!srcTy.hasStaticShape())
+    isStatic = false;
+
+  if (!ShapedType::isStatic(offset))
+    isStatic = false;
+
+  for (auto stride : strides) {
+    if (!ShapedType::isStatic(stride)) {
+      isStatic = false;
+      break;
+    }
+  }
 
   xegpu::CreateNdDescOp ndDesc;
-  if (srcTy.hasStaticShape()) {
+  if (isStatic) {
     ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src);
   } else {
-    // In case of any dynamic shapes, source's shape and strides have to be
+    // In case of ranked dynamic memref, instead of passing on the memref,
+    // i64 base address, source's offset, shape and strides have to be
     // explicitly provided.
     auto meta = memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
-    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
-                                           meta.getConstifiedMixedSizes(),
-                                           meta.getConstifiedMixedStrides());
+    auto baseAddrIndex = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, meta.getBaseBuffer());
+    auto offset = meta.getOffset();
+    auto elemByteSize = srcTy.getElementTypeBitWidth() / 8;
+    auto offsetInBytes = arith::MulIOp::create(
+        rewriter, loc, offset,
+        arith::ConstantIndexOp::create(rewriter, loc, elemByteSize));
+    auto adjustedBaseAddr = arith::AddIOp::create(
+        rewriter, loc, baseAddrIndex.getResult(), offsetInBytes);
+    auto adjustedAddrI64 = arith::IndexCastOp::create(
+        rewriter, loc, rewriter.getI64Type(), adjustedBaseAddr);
+    ndDesc = xegpu::CreateNdDescOp::create(
+        rewriter, loc, descType, adjustedAddrI64,
+        meta.getConstifiedMixedSizes(), meta.getConstifiedMixedStrides());
   }
 
   return ndDesc;

diff  --git a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
index ae5141db16c09..c77efa03f3483 100644
--- a/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
@@ -9,10 +9,17 @@ func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vecto
 // CHECK-LABEL: @load_1D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]]]{{.*}}-> vector<8xf32>
 // CHECK:       return %[[VEC]]
@@ -29,10 +36,16 @@ func.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-LABEL: @load_2D_vector(
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 
@@ -48,9 +61,15 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-LABEL: @load_dynamic_source(
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
 

diff  --git a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
index 1a10d917623cc..3c11313d05536 100644
--- a/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
@@ -11,10 +11,17 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// CHECK:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// CHECK-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
@@ -31,10 +38,16 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  %[[OFFSET:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[COLLAPSED]]
-// CHECK-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// CHECK-SAME:    : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// CHECK-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
 // -----
@@ -50,9 +63,15 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // CHECK-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// CHECK:       %[[ELEM_BYTES:.*]] = arith.constant 4 : index
 // CHECK:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// CHECK:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// CHECK:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// CHECK:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// CHECK:       %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// CHECK:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// CHECK:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// CHECK-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // -----

diff  --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 8bb272b1fe5fc..b58f9b30ed726 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -48,10 +48,16 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // LOAD-ND-LABEL:  @load_2D_vector(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
 // LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[COLLAPSED]]
-// LOAD-ND-SAME:     memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND-SAME:     : memref<f32> -> index
+// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// LOAD-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // LOAD-ND-SAME:     boundary_check = false
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFFSET]], %[[OFFSET]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
@@ -147,9 +153,16 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// LOAD-ND-SAME:                    strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:                      #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF1]], %[[OFF2]]]{{.*}}-> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]]
 
@@ -184,8 +197,15 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
 // LOAD-ND-LABEL:  @load_dynamic_source2(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x8x16xf32>,
 // LOAD-ND-SAME:   %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // LOAD-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}} : memref<8x16xf32, strided<[16, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [8, 16], strides : [16, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%{{.*}}, %{{.*}}] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
 // LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
 
@@ -459,11 +479,15 @@ gpu.func @load_from_subview_2D(%source: memref<4096x4096xf16>, %off1: index, %of
 // LOAD-ND-LABEL:  @load_from_subview_2D(
 // LOAD-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // LOAD-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// LOAD-ND:        %[[ELEM_BYTES:.+]] = arith.constant 2 : index
 // LOAD-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
-// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD-ND-SAME:     %[[SUBVIEW]]
-// LOAD-ND-SAME:     memref<256x256xf16, strided<[4096, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf16,
-// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[SUBVIEW]]
+// LOAD-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// LOAD-ND:        %[[MUL:.*]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// LOAD-ND:        %[[ADD:.*]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// LOAD-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// LOAD-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256, 256], strides : [4096, 1] :
+// LOAD-ND-SAME:                    i64 -> !xegpu.tensor_desc<8x16xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]][%[[OFF2]], %[[OFF2]]]{{.*}}-> vector<8x16xf16>
 // LOAD-ND:        return %[[VEC]]
 

diff  --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index 43a1a7206e2cc..66da64225678e 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -15,10 +15,17 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], %[[OFFSET]], 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<32xf32, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND:       %[[BASE_BUFFER:.+]], %[[OFFSET1:.+]], %[[SIZES:.+]], %[[STRIDES:.+]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND-SAME:    : memref<32xf32, strided<[1], offset: ?>> -> memref<f32>, index, index, index
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFFSET1]], %[[ELEM_BYTES]] : index
+// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND:       %[[I64PTR:.+]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [32],
+// STORE-ND-SAME:                   strides : [1] : i64  -> !xegpu.tensor_desc<8xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]]] : vector<8xf32>
 
@@ -50,10 +57,16 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
 // STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFFSET]], 0, 0]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:    %[[COLLAPSED]]
-// STORE-ND-SAME:    memref<16x32xf32, strided<[32, 1], offset: ?>> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFF1:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND-SAME:    : memref<f32> -> index
+// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFF1]], %[[ELEM_BYTES]] : index
+// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [16, 32],
+// STORE-ND-SAME:                   strides : [32, 1] : i64 -> !xegpu.tensor_desc<8x16xf32,
 // STORE-ND-SAME:    boundary_check = false
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFFSET]], %[[OFFSET]]] : vector<8x16xf32>
 
@@ -86,9 +99,15 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
 // STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
 // STORE-ND-SAME:  %[[OFF0:.+]]: index, %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND:       %[[ELEM_BYTES:.+]] = arith.constant 4 : index
 // STORE-ND:       %[[COLLAPSED:.+]] = memref.subview %[[SRC]][%[[OFF0]], 0, 0]
-// STORE-ND:       {{.*}} %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
-// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[COLLAPSED]]
+// STORE-ND:       %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.+]]:2, %[[STRIDES:.+]]:2 = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:       %[[INTPTR:.+]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]] : memref<f32> -> index
+// STORE-ND:       %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND:       %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND:       %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [%[[SIZES]]#0, %[[SIZES]]#1],
+// STORE-ND-SAME:                   strides : [%[[STRIDES]]#0, 1] : i64 -> !xegpu.tensor
 // STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF1]], %[[OFF2]]] : vector<8x16xf32>
 
 // STORE-SCATTER-LABEL: @store_dynamic_source(
@@ -293,12 +312,16 @@ gpu.func @store_to_subview(%vec: vector<8xf16>,
 // STORE-ND-SAME:   %[[VEC:.+]]: vector<8xf16>,
 // STORE-ND-SAME:   %[[SRC:.+]]: memref<4096x4096xf16>,
 // STORE-ND-SAME:   %[[OFF1:.+]]: index, %[[OFF2:.+]]: index
+// STORE-ND:        %[[ELEM_BYTES:.+]] = arith.constant 2 : index
 // STORE-ND:        %[[SUBVIEW:.+]] = memref.subview %[[SRC]][%[[OFF1]], %[[OFF2]]] [256, 256] [1, 1] : memref<4096x4096xf16> to memref<256x256xf16, strided<[4096, 1], offset: ?>>
 // STORE-ND:        %[[COLLAPSED:.+]] = memref.subview %[[SUBVIEW]][%[[OFF2]], 0]
-// STORE-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// STORE-ND-SAME:     %[[COLLAPSED]]
-// STORE-ND-SAME:     memref<256xf16, strided<[1], offset: ?>> -> !xegpu.tensor_desc<8xf16,
-// STORE-ND-SAME:     boundary_check = false
+// STORE-ND:        %[[BASE_BUFFER:.*]], %[[OFFSET:.*]], %[[SIZES:.*]], %[[STRIDES:.*]] = memref.extract_strided_metadata %[[COLLAPSED]]
+// STORE-ND:        %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[BASE_BUFFER]]
+// STORE-ND:        %[[MUL:.+]] = arith.muli %[[OFFSET]], %[[ELEM_BYTES]] : index
+// STORE-ND:        %[[ADD:.+]] = arith.addi %[[INTPTR]], %[[MUL]] : index
+// STORE-ND:        %[[I64PTR:.*]] = arith.index_cast %[[ADD]] : index to i64
+// STORE-ND:        %[[DESC:.*]] = xegpu.create_nd_tdesc %[[I64PTR]], shape : [256], strides : [1] : i64 ->
+// STORE-ND-SAME:                    !xegpu.tensor_desc<8xf16, #xegpu.block_tdesc_attr<boundary_check = false>>
 // STORE-ND:        xegpu.store_nd %[[VEC]], %[[DESC]][%[[OFF2]]] : vector<8xf16>
 
 // STORE-SCATTER-LABEL:  @store_to_subview(


        


More information about the Mlir-commits mailing list