[Mlir-commits] [mlir] [MLIR][Conversion] XeGPU to XeVM: create_nd_tdesc - Add support for base memory rank > 2 (PR #164701)
    Sang Ik Lee 
    llvmlistbot at llvm.org
       
    Wed Oct 22 13:05:57 PDT 2025
    
    
  
https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/164701
create_nd_tdesc is currently lowered to a fixed-size vector payload that encodes the base pointer, 2D base shape, and tile offsets of the base memory.
Supporting base memory of rank > 2 requires a different approach.
Consumers of the create_nd_tdesc op (load_nd, store_nd, prefetch_nd) now get base memory information directly from the defining create_nd_tdesc instead of going through the fixed payload.
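
As a rough illustration, the lines below are excerpted from the test added in this PR (loadstore_nd_high_base_rank.mlir) and annotated with how a rank-4 base memory collapses to 2D block parameters; the annotated constants match the FileCheck updates in the later patches, and the offset math assumes a row-major base layout (as the lowering does).

    // The rank-4 base memref<3x3x8x16xf32> is treated as a 2D base of
    // width 16 and height 3 * 3 * 8 = 72 rows.
    %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
    // The 4D offsets [2, 2, 0, 0] become offsetW = 0 and
    // offsetH = (2 * 3 + 2) * 8 + 0 = 64 rows.
    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
        : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>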
From 4a92953e944f31c80b047c0e90f050cbb4d983e9 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 21 Oct 2025 22:13:38 +0000
Subject: [PATCH 1/9] [MLIR][Conversion] XeGPU to XeVM: Create nd tensor
 descriptor payload for base memory rank > 2
---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 82 +++++++++++++++----
 1 file changed, 66 insertions(+), 16 deletions(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fcbf66dbe9e45..bb56c096879a3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -151,6 +151,21 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
+// Compute the product of sizes in the range [lo, hi) from the sizes array.
+static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
+                               Location loc, ArrayRef<OpFoldResult> sizes,
+                               size_t lo, size_t hi) {
+  Type indexTy = rewriter.getIndexType();
+  Value product = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  for (size_t idx = lo; idx < hi; idx++) {
+    OpFoldResult ofr = sizes[idx];
+    Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
+    sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc, indexTy, sizeVal);
+    product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
+  }
+  return product;
+}
+
 class CreateNdDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateNdDescOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -184,10 +199,9 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    // Descriptor shape is expected to be 2D.
-    int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+    auto srcRank = mixedSizes.size();
+    if (srcRank < 2)
+      return rewriter.notifyMatchFailure(op, "Expected at least 2D source.");
 
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
@@ -203,9 +217,8 @@ class CreateNdDescToXeVMPattern
       baseAddr = adaptor.getSource();
     }
     // Utility for creating offset values from op fold result.
-    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
-                            unsigned idx) -> Value {
-      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+    auto createOffset = [&](OpFoldResult ofr) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
@@ -213,8 +226,14 @@ class CreateNdDescToXeVMPattern
     offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
-    baseShapeW = createOffset(mixedSizes, 1);
-    baseShapeH = createOffset(mixedSizes, 0);
+    baseShapeW = createOffset(mixedSizes[srcRank - 1]);
+    if (srcRank == 2) {
+      baseShapeH = createOffset(mixedSizes[0]);
+    } else {
+      // Generate compute chain for height (product of sizes of all but the last
+      // dimension).
+      baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
+    }
     if (sourceMemrefTy) {
       // Cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
@@ -255,10 +274,18 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto tdVal = op.getTensorDesc();
+    xegpu::CreateNdDescOp descOp =
+        tdVal.template getDefiningOp<xegpu::CreateNdDescOp>();
+    auto mixedStrides = descOp.getMixedStrides();
     auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+    auto mixedSizes = descOp.getMixedSizes();
+    size_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != mixedStrides.size())
+      return rewriter.notifyMatchFailure(
+          op, "Offsets size should match base memory rank.");
+    if (opOffsetsSize < 2)
+      return rewriter.notifyMatchFailure(op, "Expected at least 2D offset.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -283,12 +310,35 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
     // Offsets are provided by the op.
     // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    // Offset computation assumes base memory layout is row major.
+    Value offsetW = getValueOrCreateConstantIntOp(
+        rewriter, loc, mixedOffsets[opOffsetsSize - 1]);
     offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    Value offsetH;
+    if (opOffsetsSize == 2)
+      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    else {
+      offsetH = arith::ConstantIndexOp::create(rewriter, loc, 0);
+      Value tmpStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
+      // offsetH requires computing the linear offset using the strides.
+      for (size_t idx = 0; idx < opOffsetsSize - 1; idx++) {
+        size_t revIdx = opOffsetsSize - 2 - idx;
+        Value offsetVal =
+            getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[revIdx]);
+        offsetVal = getValueOrCreateCastToIndexLike(
+            rewriter, loc, rewriter.getIndexType(), offsetVal);
+        Value mul =
+            rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, offsetVal);
+        Value dimSize =
+            getValueOrCreateConstantIntOp(rewriter, loc, mixedSizes[revIdx]);
+        dimSize = getValueOrCreateCastToIndexLike(
+            rewriter, loc, rewriter.getIndexType(), dimSize);
+        tmpStride =
+            rewriter.createOrFold<arith::MulIOp>(loc, tmpStride, dimSize);
+        offsetH = rewriter.createOrFold<arith::AddIOp>(loc, offsetH, mul);
+      }
+    }
     offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
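
The loop above that builds offsetH walks the leading offsets from the innermost towards the outermost dimension while accumulating a running row stride from the corresponding sizes. Below is a hand-folded sketch of that computation for the rank-4 test case (leading sizes [3, 3, 8], leading offsets [2, 2, 0]); the value names and pre-folded stride constants are illustrative, and the pass itself builds the chain with createOrFold, so most of it constant-folds away:

    // Running stride per leading dimension, innermost first: 1, 1*8 = 8, 8*3 = 24
    // (all measured in rows of the innermost width-16 dimension).
    %off_d2 = arith.constant 0 : index   // offset in the size-8 dimension
    %off_d1 = arith.constant 2 : index   // offset in the middle size-3 dimension
    %off_d0 = arith.constant 2 : index   // offset in the outermost size-3 dimension
    %c8 = arith.constant 8 : index
    %c24 = arith.constant 24 : index
    %t0 = arith.muli %off_d1, %c8 : index    // 2 * 8  = 16
    %t1 = arith.muli %off_d0, %c24 : index   // 2 * 24 = 48
    %t2 = arith.addi %off_d2, %t0 : index    // 0 + 16 = 16
    %offsetH = arith.addi %t2, %t1 : index   // 16 + 48 = 64 rows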
From e510643ae646688ce0c7175a7459ce1000bd4523 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 21 Oct 2025 23:00:44 +0000
Subject: [PATCH 2/9] Fix bugs and add test case for high rank base memref.
---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |  8 +++---
 .../loadstore_nd_high_base_rank.mlir          | 25 +++++++++++++++++++
 2 files changed, 30 insertions(+), 3 deletions(-)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index bb56c096879a3..a048fc0d091e6 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -152,15 +152,15 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
 }
 
 // Compute the product of sizes in the range [lo, hi) from the sizes array.
+// Note: all sizes are i64.
 static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
                                Location loc, ArrayRef<OpFoldResult> sizes,
                                size_t lo, size_t hi) {
-  Type indexTy = rewriter.getIndexType();
-  Value product = arith::ConstantIndexOp::create(rewriter, loc, 1);
+  Value product =
+      arith::ConstantIntOp::create(rewriter, loc, rewriter.getI64Type(), 1);
   for (size_t idx = lo; idx < hi; idx++) {
     OpFoldResult ofr = sizes[idx];
     Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
-    sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc, indexTy, sizeVal);
     product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
   }
   return product;
@@ -233,6 +233,8 @@ class CreateNdDescToXeVMPattern
       // Generate compute chain for height (product of sizes of all but the last
       // dimension).
       baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
+      baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
+                                                   baseShapeH);
     }
     if (sourceMemrefTy) {
       // Cast index to i64.
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
new file mode 100644
index 0000000000000..60d1dcf18d634
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @load_store_check {
+    // CHECK: fail
+    gpu.func @load_store(%src: memref<3x3x8x16xf32, 1>, %dst: memref<3x3x8x16xf32, 1>) kernel {
+        %srcce = memref.memory_space_cast %src : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+        %dstte = memref.memory_space_cast %dst : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+        %tid_x = gpu.thread_id x
+        %tid_x_i32 = arith.index_cast %tid_x : index to i32
+        %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+        %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+}
From 4e4cbd06de5cebb304cbc2ef64421c52a20b24ac Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 00:27:09 +0000
Subject: [PATCH 3/9] Replace 2D block load payload with i64.
---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 131 ++++--------------
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |   2 +-
 .../loadstore_nd_high_base_rank.mlir          |   2 +-
 .../XeGPUToXeVM/loadstore_nd_int_addr.mlir    |  75 ++++++++++
 .../Conversion/XeGPUToXeVM/prefetch_nd.mlir   |   6 +-
 5 files changed, 106 insertions(+), 110 deletions(-)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a048fc0d091e6..d7db65d662faf 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -48,15 +48,6 @@ namespace {
 static constexpr int32_t systolicDepth{8};
 static constexpr int32_t executionSize{16};
 
-// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
-enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
-};
-
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   switch (xeGpuMemspace) {
   case xegpu::MemorySpace::Global:
@@ -177,92 +168,14 @@ class CreateNdDescToXeVMPattern
     if (mixedOffsets.size() != 0)
       return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
-    auto source = op.getSource();
-    // Op is lowered to a code sequence that populates payload.
-    // Payload is a 8xi32 vector. Offset to individual fields are defined in
-    // NdTdescOffset enum.
-    Type payloadElemTy = rewriter.getI32Type();
-    VectorType payloadTy = VectorType::get(8, payloadElemTy);
-    Type i64Ty = rewriter.getI64Type();
-    // 4xi64 view is used for inserting the base pointer.
-    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
-    // Initialize payload to zero.
-    Value payload = arith::ConstantOp::create(
-        rewriter, loc,
-        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
-
-    Value baseAddr;
-    Value baseShapeW;
-    Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
 
-    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
-    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    auto srcRank = mixedSizes.size();
-    if (srcRank < 2)
-      return rewriter.notifyMatchFailure(op, "Expected at least 2D source.");
-
-    auto sourceTy = source.getType();
-    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
-    // If source is a memref, we need to extract the aligned pointer as index.
-    // Pointer type is passed as i32 or i64 by type converter.
-    if (sourceMemrefTy) {
-      if (!sourceMemrefTy.hasStaticShape()) {
-        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
-      }
-      baseAddr =
-          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
-    } else {
-      baseAddr = adaptor.getSource();
-    }
-    // Utility for creating offset values from op fold result.
-    auto createOffset = [&](OpFoldResult ofr) -> Value {
-      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
-      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
-      return val;
-    };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    // Get shape values from op fold results.
-    baseShapeW = createOffset(mixedSizes[srcRank - 1]);
-    if (srcRank == 2) {
-      baseShapeH = createOffset(mixedSizes[0]);
-    } else {
-      // Generate compute chain for height (product of sizes of all but the last
-      // dimension).
-      baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
-      baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
-                                                   baseShapeH);
-    }
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
+    Value baseAddr = adaptor.getSource();
+    Type i64Ty = rewriter.getI64Type();
+    if (baseAddr.getType() != i64Ty) {
       // Pointer type may be i32. Cast to i64 if needed.
       baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
     }
-    // Populate payload.
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
-    payLoadAsI64 =
-        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
-                                 static_cast<int>(NdTdescOffset::BasePtr));
-    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
-    payload =
-        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
-                                 static_cast<int>(NdTdescOffset::BaseShapeW));
-    payload =
-        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
-                                 static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
-    rewriter.replaceOp(op, payload);
+    rewriter.replaceOp(op, baseAddr);
     return success();
   }
 };
@@ -291,7 +204,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
-    auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
     if (tdescTy.getRank() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
@@ -301,15 +213,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    Value basePtr = adaptor.getTensorDesc();
+    // Utility for creating offset values from op fold result.
+    Type payloadElemTy = rewriter.getIntegerType(32);
+    auto createOffset = [&](OpFoldResult ofr) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
+      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
+      return val;
+    };
+    auto srcRank = mixedSizes.size();
+    // Get shape values from op fold results.
+    Value baseShapeW = createOffset(mixedSizes[srcRank - 1]);
+    Value baseShapeH;
+    if (srcRank == 2) {
+      baseShapeH = createOffset(mixedSizes[0]);
+    } else {
+      // Generate compute chain for height (product of sizes of all but the last
+      // dimension).
+      baseShapeH = getProductOfSizes(rewriter, loc, mixedSizes, 0, srcRank - 1);
+      baseShapeH = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy,
+                                                   baseShapeH);
+    }
     // Offsets are provided by the op.
     // convert them to i32.
     // Offset computation assumes base memory layout is row major.
@@ -979,10 +903,7 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
-      if (type.isScattered())
-        return IntegerType::get(&getContext(), 64);
-      auto i32Type = IntegerType::get(&getContext(), 32);
-      return VectorType::get(8, i32Type);
+      return IntegerType::get(&getContext(), 64);
     });
     // Convert MemDescType into flattened MemRefType for SLM
     typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf25b4728..e4b206842e069 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
     gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
index 60d1dcf18d634..e328517634b03 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
     // CHECK: fail
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
new file mode 100644
index 0000000000000..ca458eaf231c0
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+    gpu.func @load_store(%src: ui64, %dst: ui32) kernel {
+        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+        %c8 = arith.constant 8 : index
+        %c16 = arith.constant 16 : index
+        %c1 = arith.constant 1 : index
+        %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%c8, %c16], strides:[%c16, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+
+        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
+        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
+        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
+        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
+        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
+        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
+        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
+
+        %tid_x = gpu.thread_id x
+        %tid_x_i32 = arith.index_cast %tid_x : index to i32
+        %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+        %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+        // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+        %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%c8, %c16], strides:[%c16, %c1] : ui32 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
+        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
+        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
+        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
+        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
+        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
+        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
+        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
+        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
+        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
+        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
+        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
+        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 873478aed57e3..ae87951a33447 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -1,7 +1,7 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-gpu.module @fence_check {
-    gpu.func @fence(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+gpu.module @prefetch_nd_check {
+    gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
From 2546a37b2a955eb0a3e6133f458bddc7c6e4b2d0 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 10:41:20 -0700
Subject: [PATCH 4/9] Update test check.
---
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  | 74 ++++++-------------
 1 file changed, 23 insertions(+), 51 deletions(-)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index e4b206842e069..0764129cfd447 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,73 +1,45 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
+    // CHECK-LABEL: gpu.func @load_store
+    // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>
     gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        // CHECK: %[[C64_i32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C8_i32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]]
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST:.*]] : memref<8x16xf32> -> index
+        // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]]
+        // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
+        // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_1:.*]] : index to i64
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
-
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
-        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-        //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
-        //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[VAR3:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_i32]], %[[C8_i32]], %[[C64_i32]],
+        // CHECK-SAME:  %[[C0_i32]], %[[C0_i32]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>,
+        // CHECK-SAME:  elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+        // CHECK-SAME:  tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+        // CHECK: %[[VAR4:.*]] = vector.bitcast %[[VAR3]] : vector<8xi32> to vector<8xf32>
         %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
         %tid_x = gpu.thread_id x
         %tid_x_i32 = arith.index_cast %tid_x : index to i32
         %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+        // CHECK: %[[VAR7:.*]] = vector.insert
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-        // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-        //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[VAR9:.*]] = vector.bitcast %[[VAR7]] : vector<8xf32> to vector<8xi32>
+        // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_i32]], %[[C8_i32]], %[[C64_i32]], %[[C0_i32]], %[[C0_i32]], %[[VAR9]]
+        // CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}>
         xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         gpu.return
From 05c889acaa07a47e3aa13de207a01a0a5c85788d Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 10:52:43 -0700
Subject: [PATCH 5/9] Update test check.
---
 .../Conversion/XeGPUToXeVM/prefetch_nd.mlir   | 50 ++++++++-----------
 1 file changed, 22 insertions(+), 28 deletions(-)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index ae87951a33447..09f2108cc5aed 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -1,40 +1,34 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
 gpu.module @prefetch_nd_check {
+    // CHECK-LABEL: gpu.func @prefetch_nd(
+    // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
     gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
-            #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
-
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
-        //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
-        //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
-        //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
-        //CHECK-SAME:   %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
-        //CHECK-SAME:   <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-        //CHECK-SAME:     tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
-        //CHECK-SAME:   : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+            #xegpu.block_tdesc_attr<memory_space = global>>
+        // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
+        // CHECK: %[[VAR1:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+        // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
+        // CHECK: %[[VAR2:.*]] = arith.trunci %[[C8_I64]] : i64 to i32
+        // CHECK: %[[C0_I64:.*]] = arith.constant 0 : i64
+        // CHECK: %[[VAR3:.*]] = arith.trunci %[[C0_I64]] : i64 to i32
+        // CHECK: %[[C0_I64_1:.*]] = arith.constant 0 : i64
+        // CHECK: %[[VAR4:.*]] = arith.trunci %[[C0_I64_1]] : i64 to i32
+        // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[C4_I32:.*]] = arith.constant 4 : i32
+        // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR1]], %[[C4_I32]] : i32
+        // CHECK: xevm.blockprefetch2d %[[VAR5]], %[[VAR6]], %[[VAR2]], %[[VAR6]], %[[VAR3]], %[[VAR4]]
+        // CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        // CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
         xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-            : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>,
-                  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+            : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
         gpu.return
     }
From 9a2ea5f5b45586bff911509739aa0cef66e61933 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 17:55:53 +0000
Subject: [PATCH 6/9] Update test check.
---
 .../XeGPUToXeVM/create_nd_tdesc.mlir          | 48 ++++++-------------
 1 file changed, 14 insertions(+), 34 deletions(-)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index d6e36fa73bf04..38d2c6483c204 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -4,45 +4,25 @@ gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
   // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
+  // CHECK-SAME: %[[ARG8:.*]]: memref<?x?xf16>) kernel {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
-  %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
-        // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
-        // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
-        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
-        // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
-        // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
-        // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+  %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
+        // Optimized away
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
-
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+        // CHECK-NEXT: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
         %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
-
-        // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-        // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
-        // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
-        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
-        // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-        // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-        // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        // Optimized away
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+        // CHECK-NEXT: %c1 = arith.constant 1 : index
+        %c1 = arith.constant 1 : index
+        // CHECK-NEXT: %c64 = arith.constant 64 : index
+        %size_x = arith.constant 64 : index
+        // CHECK-NEXT: %c16 = arith.constant 16 : index
+        %BLOCK_DMODEL = arith.constant 16 : index
+        // Optimized away
+        %dyn_tdesc  = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
+        // CHECK-NEXT: gpu.return
         gpu.return
     }
 }
From a79bd40420abb3956a49cff69e44484d5fbc7887 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 18:23:57 +0000
Subject: [PATCH 7/9] Update test check.
---
 .../loadstore_nd_high_base_rank.mlir          | 24 ++++++-
 .../XeGPUToXeVM/loadstore_nd_int_addr.mlir    | 72 ++++++-------------
 2 files changed, 45 insertions(+), 51 deletions(-)
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
index e328517634b03..d80f12c06a58a 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank.mlir
@@ -1,23 +1,45 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
-    // CHECK: fail
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: memref<3x3x8x16xf32, 1>, %[[ARG1:.*]]: memref<3x3x8x16xf32, 1>) kernel {
     gpu.func @load_store(%src: memref<3x3x8x16xf32, 1>, %dst: memref<3x3x8x16xf32, 1>) kernel {
+      // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+      // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+      // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+      // CHECK: %[[C72_I32:.*]] = arith.constant 72 : i32
+      // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+      // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<3x3x8x16xf32> -> index
+      // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         %srcce = memref.memory_space_cast %src : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+      // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
+      // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<3x3x8x16xf32> -> index
+      // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
         %dstte = memref.memory_space_cast %dst : memref<3x3x8x16xf32, 1> to memref<3x3x8x16xf32>
 
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
+      // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+      // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+      // CHECK-SAME: %[[C0_I32]], %[[C64_I32]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>,
+      // CHECK-SAME:  elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+      // CHECK-SAME:  tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
         %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
 
         %tid_x = gpu.thread_id x
         %tid_x_i32 = arith.index_cast %tid_x : index to i32
         %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+      // CHECK: %[[VAR7:.*]] = vector.insert
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<3x3x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
+      // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+      // CHECK: %[[VAR9:.*]] = vector.bitcast %[[VAR7]] : vector<8xf32> to vector<8xi32>
+      // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]], %[[C0_I32]], %[[C32_I32]], %[[VAR9]]
+      // CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>,
+      // CHECK-SAME:  elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
         xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         gpu.return
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
index ca458eaf231c0..c8ce0b3021b3f 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_int_addr.mlir
@@ -1,73 +1,45 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
+    // CHECK-LABEL: gpu.func @load_store
+    // CHECK-SAME: %[[ARG0:.*]]: ui64, %[[ARG1:.*]]: ui32) kernel {
     gpu.func @load_store(%src: ui64, %dst: ui32) kernel {
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C0_I32:.*]] = arith.constant 0
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[ARG1_IDX:.*]] = index.castu %[[ARG1]] : ui32 to index
+        // CHECK: %[[ARG1_I32:.*]] = arith.index_castui %[[ARG1_IDX]] : index to i32
+        // CHECK: %[[ARG0_IDX:.*]] = index.castu %[[ARG0]] : ui64 to index
+        // CHECK: %[[ARG0_I64:.*]] = arith.index_castui %[[ARG0_IDX]] : index to i64
         %c8 = arith.constant 8 : index
         %c16 = arith.constant 16 : index
         %c1 = arith.constant 1 : index
         %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%c8, %c16], strides:[%c16, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
 
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
-        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-        //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
-        //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[ARG0_I64]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[LOAD:.*]] = xevm.blockload2d %[[VAR4]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]],
+        // CHECK-SAME:  %[[C0_I32]], %[[C0_I32]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>,
+        // CHECK-SAME:  elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+        // CHECK-SAME:  tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+        // CHECK: %[[VAR6:.*]] = vector.bitcast %[[LOAD]] : vector<8xi32> to vector<8xf32>
         %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
         %tid_x = gpu.thread_id x
         %tid_x_i32 = arith.index_cast %tid_x : index to i32
         %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+        // CHECK: %[[VAR9:.*]] = vector.insert
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-        // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR10:.*]] = arith.extui %[[ARG1_I32]] : i32 to i64
         %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%c8, %c16], strides:[%c16, %c1] : ui32 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
-        //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        // CHECK: %[[VAR11:.*]] = llvm.inttoptr %[[VAR10]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[STORE:.*]] = vector.bitcast %[[VAR9]] : vector<8xf32> to vector<8xi32>
+        // CHECK: xevm.blockstore2d %[[VAR11]], %[[C64_I32]], %[[C8_I32]], %[[C64_I32]], %[[C0_I32]], %[[C0_I32]], %[[STORE]]
+        // CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>,
+        // CHECK-SAME:  elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
         xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         gpu.return
>From e053be1947ed9a35216b00afba3e1613066f6eb3 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 18:58:10 +0000
Subject: [PATCH 8/9] Fix dynamic stride computation issue and add test case.
---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d7db65d662faf..38303e33e884c 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -143,7 +143,6 @@ translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
 }
 
 // Compute the product of sizes in the range [lo, hi) from the sizes array.
-// Note: all sizes are i64.
 static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
                                Location loc, ArrayRef<OpFoldResult> sizes,
                                size_t lo, size_t hi) {
@@ -152,6 +151,8 @@ static Value getProductOfSizes(ConversionPatternRewriter &rewriter,
   for (size_t idx = lo; idx < hi; idx++) {
     OpFoldResult ofr = sizes[idx];
     Value sizeVal = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
+    sizeVal = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI64Type(), sizeVal);
     product = rewriter.createOrFold<arith::MulIOp>(loc, product, sizeVal);
   }
   return product;
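
The added cast normalizes each dynamic size to i64 before it is folded into the running product; without it, an index-typed size would be multiplied directly against the i64 accumulator, which appears to be the mismatch the patch subject refers to. A minimal sketch of the IR this is expected to produce for three dynamic, index-typed sizes (illustrative names; after -canonicalize the leading multiply by the constant 1 folds away, and the shape matches the CHECK lines in the new int-address test below):

    %d0 = arith.index_cast %dim0 : index to i64
    %d1 = arith.index_cast %dim1 : index to i64
    %p0 = arith.muli %d0, %d1 : i64
    %d2 = arith.index_cast %dim2 : index to i64
    %p1 = arith.muli %p0, %d2 : i64
    // Consumers truncate the accumulated extent to i32 for the 2D block ops.
    %h  = arith.trunci %p1 : i64 to i32
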
>From 88ab9aaf6a60765f09508c73275bc093d1adc956 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 22 Oct 2025 19:17:07 +0000
Subject: [PATCH 9/9] Add more high rank base memory test cases.
---
 .../loadstore_nd_high_base_rank_dynamic.mlir  | 54 +++++++++++++++++++
 .../loadstore_nd_high_base_rank_int_addr.mlir | 52 ++++++++++++++++++
 2 files changed, 106 insertions(+)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir
new file mode 100644
index 0000000000000..16ecd978ad307
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_dynamic.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME: %[[ARG0:.*]]: memref<?x?x?x?xf32>, %[[ARG1:.*]]: memref<?x?x?x?xf32>) kernel {
+  gpu.func @load_store(%src: memref<?x?x?x?xf32>, %dst: memref<?x?x?x?xf32>) kernel {
+    // CHECK: %[[C32_I32:.*]] = arith.constant 32 : i32
+    // CHECK: %[[C64_I32:.*]] = arith.constant 64 : i32
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C72_I32:.*]] = arith.constant 72 : i32
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<?x?x?x?xf32> -> index
+    // CHECK: %[[VAR0:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[INTPTR_0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<?x?x?x?xf32> -> index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[INTPTR_0]] : index to i64
+    %dim0 = arith.constant 3 : index
+    %dim1 = arith.constant 3 : index
+    %dim2 = arith.constant 8 : index
+    %dim3 = arith.constant 16 : index
+    %stride3 = arith.constant 1 : index
+    %stride2 = arith.constant 16 : index
+    %stride1 = arith.constant 128 : index
+    %stride0 = arith.constant 384 : index
+
+    %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%dim0, %dim1, %dim2, %dim3],
+                   strides:[%stride0, %stride1, %stride2, %stride3] : memref<?x?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[VAR2:.*]] = llvm.inttoptr %[[VAR1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[VAR2]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+    // CHECK-SAME:  %[[C0_I32]], %[[C64_I32]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>,
+    // CHECK-SAME:  elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32,
+    // CHECK-SAME:  tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[LOADED_F32:.*]] = vector.bitcast %[[LOADED]] : vector<8xi32> to vector<8xf32>
+    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[LOADED_MODIFIED:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%dim0, %dim1, %dim2, %dim3],
+                   strides:[%stride0, %stride1, %stride2, %stride3] : memref<?x?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+    // CHECK: %[[VAR8:.*]] = llvm.inttoptr %[[VAR0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED_MODIFIED_BC:.*]] = vector.bitcast %[[LOADED_MODIFIED]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[VAR8]], %[[C64_I32]], %[[C72_I32]], %[[C64_I32]],
+    // CHECK-SAME:  %[[C0_I32]], %[[C32_I32]], %[[LOADED_MODIFIED_BC]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>,
+    // CHECK-SAME:  elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+    gpu.return
+  }
+}
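
As a quick sanity check on the constants expected by the CHECK lines above (not part of the test, just the arithmetic behind them, assuming the row-collapsing scheme the int-address test below spells out symbolically):

    base width / pitch : 16 cols * 4 bytes          = 64  (C64_I32)
    base height        : dim0 * dim1 * dim2 = 3*3*8 = 72  (C72_I32)
    load  row offset   : 2*(3*8) + 2*8              = 64  for offsets [2, 2, 0, 0]
    store row offset   : 1*(3*8) + 1*8              = 32  (C32_I32) for offsets [1, 1, 0, 0]
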
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir
new file mode 100644
index 0000000000000..428534c628314
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd_high_base_rank_int_addr.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize -cse %s | FileCheck %s
+
+gpu.module @load_store_check {
+  // CHECK-LABEL: gpu.func @load_store
+  // CHECK-SAME:  %[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64,
+  // CHECK-SAME:  %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index, %[[ARG8:.*]]: index, %[[ARG9:.*]]: index
+  gpu.func @load_store(%src: i64, %dst: i64, %dim0: index, %dim1: index, %dim2: index, %dim3: index,
+                       %stride0: index, %stride1: index, %stride2: index, %stride3: index) kernel {
+    // CHECK: %[[C2:.*]] = arith.constant 2 : index
+    // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+    // CHECK: %[[C4_I32:.*]] = arith.constant 4 : i32
+    // CHECK: %[[VAR0:.*]] = arith.index_cast %[[ARG5]] : index to i32
+    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[ARG2]] : index to i64
+    // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
+    // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[VAR2]] : i64
+    // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG4]] : index to i64
+    // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[VAR4]] : i64
+    // CHECK: %[[VAR6:.*]] = arith.trunci %[[VAR5]] : i64 to i32
+    // CHECK: %[[VAR7:.*]] = arith.muli %[[ARG4]], %[[C2]] : index
+    // CHECK: %[[VAR8:.*]] = arith.muli %[[ARG4]], %[[ARG3]] : index
+    // CHECK: %[[VAR9:.*]] = arith.muli %[[VAR8]], %[[C2]] : index
+    // CHECK: %[[VAR10:.*]] = arith.addi %[[VAR7]], %[[VAR9]] : index
+    // CHECK: %[[VAR11:.*]] = arith.index_cast %[[VAR10]] : index to i32
+    %src_tdesc = xegpu.create_nd_tdesc %src, shape:[%dim0, %dim1, %dim2, %dim3],
+                   strides:[%stride0, %stride1, %stride2, %stride3] : i64 -> !xegpu.tensor_desc<8x16xf32>
+
+    // CHECK: %[[SRC_PTR:.*]] = llvm.inttoptr %[[ARG0]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[VAR13:.*]] = arith.muli %[[VAR0]], %[[C4_I32]] : i32
+    // CHECK: %[[LOADED:.*]] = xevm.blockload2d %[[SRC_PTR]], %[[VAR13]], %[[VAR6]], %[[VAR13]], %[[C0_I32]], %[[VAR11]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32, pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false, v_blocks = 1 : i32}>
+    // CHECK: %[[VAR15:.*]] = vector.bitcast %[[LOADED]] : vector<8xi32> to vector<8xf32>
+    %loaded = xegpu.load_nd %src_tdesc[2, 2, 0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+
+    %tid_x = gpu.thread_id x
+    %tid_x_i32 = arith.index_cast %tid_x : index to i32
+    %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+    // CHECK: %[[LOADED_MODIFIED:.*]] = vector.insert
+    %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+    // CHECK: %[[VAR19:.*]] = arith.addi %[[ARG4]], %[[VAR8]] : index
+    // CHECK: %[[VAR20:.*]] = arith.index_cast %[[VAR19]] : index to i32
+    %dst_tdesc = xegpu.create_nd_tdesc %dst, shape:[%dim0, %dim1, %dim2, %dim3],
+                   strides:[%stride0, %stride1, %stride2, %stride3] : i64 -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+    // CHECK: %[[DST_PTR:.*]] = llvm.inttoptr %[[ARG1]] : i64 to !llvm.ptr<1>
+    // CHECK: %[[LOADED_MODIFIED_BITCAST:.*]] = vector.bitcast %[[LOADED_MODIFIED]] : vector<8xf32> to vector<8xi32>
+    // CHECK: xevm.blockstore2d %[[DST_PTR]], %[[VAR13]], %[[VAR6]], %[[VAR13]], %[[C0_I32]], %[[VAR20]], %[[LOADED_MODIFIED_BITCAST]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32, tile_height = 8 : i32, tile_width = 16 : i32}>
+    xegpu.store_nd %loaded_modified, %dst_tdesc[1, 1, 0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+    gpu.return
+  }
+}
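
For reference, the symbolic values captured by the CHECK lines above work out as follows (offsets [2, 2, 0, 0] for the load and [1, 1, 0, 0] for the store; dims and strides are the function arguments):

    VAR13 (base width / pitch, bytes) = dim3 * 4
    VAR6  (base height)               = dim0 * dim1 * dim2
    VAR11 (load  row offset)          = 2*dim2 + 2*dim1*dim2
    VAR20 (store row offset)          = dim2 + dim1*dim2
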
    
    