[Mlir-commits] [mlir] [MLIR][Conversion] XeGPU to XeVM: Add handler for 1D block ops (PR #165894)

Sang Ik Lee llvmlistbot at llvm.org
Mon Nov 10 08:49:53 PST 2025


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/165894

From 319205c9384078d663e375ed5fae57ea1db48b3a Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 30 Oct 2025 23:07:05 +0000
Subject: [PATCH 1/4] Isolate 2D code.
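
This is a pure refactor with no intended functional change: the existing 2D
block load/store/prefetch lowering is moved under a tensor-descriptor
rank == 2 guard so a 1D path can be added alongside it in a follow-up commit.
For reference, the 2D forms that continue to take this path look roughly as
follows (a minimal sketch; shapes, offsets, and result type are illustrative,
modeled on the existing loadstore_nd.mlir test):

    %tdesc = xegpu.create_nd_tdesc %src : memref<8x16xf32>
        -> !xegpu.tensor_desc<8x16xf32>
    %v = xegpu.load_nd %tdesc[0, 0]
        : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>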

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 128 +++++++++---------
 1 file changed, 64 insertions(+), 64 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 33e8f2ed1f6ed..a869ce09181a8 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -264,8 +264,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
-    if (tdescTy.getRank() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
@@ -294,71 +292,73 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Convert base pointer (i64) to LLVM pointer type.
-    Value basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-    // Compute element byte size and surface width in bytes.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    Value surfaceW =
-        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
-
-    // Get tile sizes and vblocks from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(1);
-    auto tileH = tdescTy.getDimSize(0);
-    int32_t vblocks = tdescTy.getArrayLength();
-    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      Value src = adaptor.getValue();
-      // If store value is a scalar, get value from op instead of adaptor.
-      // Adaptor might have optimized away single element vector
-      if (src.getType().isIntOrFloat()) {
-        src = op.getValue();
-      }
-      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
-      if (!srcVecTy)
-        return rewriter.notifyMatchFailure(
-            op, "Expected store value to be a vector type.");
-      // Get flat vector type of integer type with matching element bit size.
-      VectorType newSrcVecTy =
-          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
-      if (srcVecTy != newSrcVecTy)
-        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      xevm::BlockStore2dOp::create(
-          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, src,
-          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
-      rewriter.eraseOp(op);
-    } else {
-      auto loadCacheControl =
-          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
-        xevm::BlockPrefetch2dOp::create(
+    if (tdescTy.getRank() == 2) {
+      // Convert base pointer (i64) to LLVM pointer type.
+      Value basePtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+      // Compute element byte size and surface width in bytes.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      Value surfaceW =
+          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+      // Get tile sizes and vblocks from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(1);
+      auto tileH = tdescTy.getDimSize(0);
+      int32_t vblocks = tdescTy.getArrayLength();
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If the store value is a scalar, take it from the op instead of the
+        // adaptor, which may have optimized away a single-element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get a flat vector type with integer elements of matching bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        xevm::BlockStore2dOp::create(
             rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, vblocks,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+            offsetH, elemBitSize, tileW, tileH, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
-        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
-        const bool vnni = op.getPacked().value_or(false);
-        auto transposeValue = op.getTranspose();
-        bool transpose =
-            transposeValue.has_value() && transposeValue.value()[0] == 1;
-        VectorType loadedTy = encodeVectorTypeTo(
-            dstVecTy, vnni ? rewriter.getI32Type()
-                           : rewriter.getIntegerType(elemBitSize));
-
-        Value resultFlatVec = xevm::BlockLoad2dOp::create(
-            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-            transpose, vnni,
-            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
-        resultFlatVec = vector::BitCastOp::create(
-            rewriter, loc,
-            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
-            resultFlatVec);
-        rewriter.replaceOp(op, resultFlatVec);
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+          xevm::BlockPrefetch2dOp::create(
+              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
+              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          rewriter.eraseOp(op);
+        } else {
+          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+          const bool vnni = op.getPacked().value_or(false);
+          auto transposeValue = op.getTranspose();
+          bool transpose =
+              transposeValue.has_value() && transposeValue.value()[0] == 1;
+          VectorType loadedTy = encodeVectorTypeTo(
+              dstVecTy, vnni ? rewriter.getI32Type()
+                             : rewriter.getIntegerType(elemBitSize));
+
+          Value resultFlatVec = xevm::BlockLoad2dOp::create(
+              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+              transpose, vnni,
+              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+          resultFlatVec = vector::BitCastOp::create(
+              rewriter, loc,
+              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+              resultFlatVec);
+          rewriter.replaceOp(op, resultFlatVec);
+        }
       }
     }
     return success();

From fd3186e3311fbc314963fd2372aad117113311c8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 30 Oct 2025 23:33:40 +0000
Subject: [PATCH 2/4] Add handler for 1D block load_nd/store_nd and test case.
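
For a rank-1 tensor descriptor there is no 2D block payload: the lowering
folds the offsets into a flat byte offset from the base address and emits
xevm.blockload / xevm.blockstore; xegpu.prefetch_nd with a rank-1 descriptor
is rejected with a match failure. As a rough sketch, a 1D store lowers to the
following (SSA names are illustrative and cache-control attributes are
omitted; see the test added with this change):

    %addr = arith.addi %base, %byte_offset : i64
    %ptr = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
    xevm.blockstore %ptr, %v : (!llvm.ptr<1>, vector<2xi32>)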

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 73 ++++++++++++++++++-
 1 file changed, 69 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a869ce09181a8..bc360c1a6929b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -292,19 +292,25 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    if (tdescTy.getRank() == 2) {
+    // Compute element byte size.
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    auto tileRank = tdescTy.getRank();
+    // Get tile width from the tensor descriptor type.
+    auto tileW = tdescTy.getDimSize(tileRank - 1);
+    if (tileRank == 2) {
       // Convert base pointer (i64) to LLVM pointer type.
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
-      // Compute element byte size and surface width in bytes.
+      // Compute surface width in bytes.
       Value elemByteSize = arith::ConstantIntOp::create(
           rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
       Value surfaceW =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
 
-      // Get tile sizes and vblocks from the tensor descriptor type.
-      auto tileW = tdescTy.getDimSize(1);
+      // Get tile height from the tensor descriptor type.
       auto tileH = tdescTy.getDimSize(0);
+      // Get vblocks from the tensor descriptor type.
       int32_t vblocks = tdescTy.getArrayLength();
       if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
         Value src = adaptor.getValue();
@@ -360,6 +366,65 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           rewriter.replaceOp(op, resultFlatVec);
         }
       }
+    } else {
+      // Get address from base address and offsets.
+      // The offset is in elements; multiply by the element byte size.
+      // Compute linear offset.
+      // linearOffset = offsetH * baseShapeW + offsetW
+      Value offsetHInElems =
+          rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
+      Value linearOffset =
+          rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
+      // Then scale by the element byte size to get the byte offset.
+      // byteOffset = linearOffset * elemByteSize
+      Value byteOffset =
+          rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
+      // Final address = basePtr + byteOffset
+      Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
+          loc, basePtr,
+          getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
+                                          byteOffset));
+      // Convert the final address (i64) to LLVM pointer type.
+      Value finalPtrLLVM =
+          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
+      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+        Value src = adaptor.getValue();
+        // If the store value is a scalar, take it from the op instead of the
+        // adaptor, which may have optimized away a single-element vector.
+        if (src.getType().isIntOrFloat()) {
+          src = op.getValue();
+        }
+        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
+        if (!srcVecTy)
+          return rewriter.notifyMatchFailure(
+              op, "Expected store value to be a vector type.");
+        // Get a flat vector type with integer elements of matching bit size.
+        VectorType newSrcVecTy =
+            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+        if (srcVecTy != newSrcVecTy)
+          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+        auto storeCacheControl =
+            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
+            op, finalPtrLLVM, src,
+            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
+        auto loadCacheControl =
+            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+        VectorType resTy = cast<VectorType>(op.getValue().getType());
+        VectorType loadedTy =
+            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
+        Value load = xevm::BlockLoadOp::create(
+            rewriter, loc, loadedTy, finalPtrLLVM,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        if (loadedTy != resTy)
+          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
+        rewriter.replaceOp(op, load);
+      } else {
+        return rewriter.notifyMatchFailure(
+            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
+                "descriptor rank == 1");
+      }
     }
     return success();
   }

From b2c9ffad3ece562c3719ac2af65cc212623c126b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Fri, 31 Oct 2025 18:07:41 +0000
Subject: [PATCH 3/4] Add missing file.

---
 .../Conversion/XeGPUToXeVM/loadstore_1d.mlir  | 57 +++++++++++++++++++
 1 file changed, 57 insertions(+)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir

diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
new file mode 100644
index 0000000000000..54804ee6cd033
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
+
+gpu.module @load_store_check {
+    // CHECK-LABEL: @load_store(
+    // CHECK-SAME: %[[SRC:.*]]: memref<8x64xf32, 1>, %[[DST:.*]]: memref<8x32xf32, 1>
+    gpu.func @load_store(%src: memref<8x64xf32, 1>, %dst: memref<8x32xf32, 1>) kernel {
+        // CHECK: %[[C512:.*]] = arith.constant 512 : i64
+        // CHECK: %[[C32:.*]] = arith.constant 32 : i32
+        // CHECK: %[[C384:.*]] = arith.constant 384 : i64
+        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
+        // CHECK: %[[C8:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C64:.*]] = arith.constant 64 : i32
+        // CHECK: %[[C0:.*]] = arith.constant 0 : i32
+
+        // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<8x64xf32, 1> to memref<8x64xf32>
+        %srcce = memref.memory_space_cast %src : memref<8x64xf32, 1> to memref<8x64xf32>
+        // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<8x32xf32, 1> to memref<8x32xf32>
+        %dstte = memref.memory_space_cast %dst : memref<8x32xf32, 1> to memref<8x32xf32>
+
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<8x64xf32> -> index
+        // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+        // CHECK: %[[VEC1:.*]] = vector.insert %[[INTPTR_I64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VEC2:.*]] = vector.bitcast %[[VEC1]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VEC3:.*]] = vector.insert %[[C64]], %[[VEC2]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VEC4:.*]] = vector.insert %[[C8]], %[[VEC3]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VEC5:.*]] = vector.insert %[[C0]], %[[VEC4]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VEC6:.*]] = vector.insert %[[C0]], %[[VEC5]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x64xf32> -> !xegpu.tensor_desc<32xf32>
+        // CHECK: %[[VEC7:.*]] = vector.bitcast %[[VEC6]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[EXTR:.*]] = vector.extract %[[VEC7]][0] : i64 from vector<4xi64>
+        // CHECK: %[[ADDR:.*]] = arith.addi %[[EXTR]], %[[C384]] : i64
+        // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
+        // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
+        // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
+        %loaded = xegpu.load_nd %src_tdesc[1, 32] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
+
+        // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<8x32xf32> -> index
+        // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
+        // CHECK: %[[VEC1_1:.*]] = vector.insert %[[INTPTR1_I64]], %[[CST]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VEC2_1:.*]] = vector.bitcast %[[VEC1_1]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VEC3_1:.*]] = vector.insert %[[C32]], %[[VEC2_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VEC4_1:.*]] = vector.insert %[[C8]], %[[VEC3_1]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VEC5_1:.*]] = vector.insert %[[C0]], %[[VEC4_1]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VEC6_1:.*]] = vector.insert %[[C0]], %[[VEC5_1]] [5] : i32 into vector<8xi32>
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        // CHECK: %[[VEC7_1:.*]] = vector.bitcast %[[VEC6_1]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[EXTR1:.*]] = vector.extract %[[VEC7_1]][0] : i64 from vector<4xi64>
+        // CHECK: %[[ADDR1:.*]] = arith.addi %[[EXTR1]], %[[C512]] : i64
+        // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
+        // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
+        // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
+        xegpu.store_nd %loaded, %dst_tdesc[4, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+}

From c7a59a92af7315229a21d69083b2e1b4cc9037c8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 4 Nov 2025 18:46:27 +0000
Subject: [PATCH 4/4] Address reviewer comments.
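
Main changes: a 1D tensor descriptor now type-converts to a plain i64 base
address instead of a vector<8xi32> payload, the op's offset rank is checked
against the descriptor rank, and the element-byte-size constant is created
only on the path that uses it (which reorders a few constants in the CHECK
lines). With canonicalization, the 1D load in loadstore_1d.mlir now lowers
to roughly the following (constants illustrative: an offset of 96 f32
elements is 384 bytes):

    %base = arith.index_castui %intptr : index to i64
    %addr = arith.addi %base, %c384 : i64
    %ptr = llvm.inttoptr %addr : i64 to !llvm.ptr<1>
    %v = xevm.blockload %ptr : (!llvm.ptr<1>) -> vector<2xi32>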

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 106 +++++++++---------
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |   4 +-
 .../Conversion/XeGPUToXeVM/loadstore_1d.mlir  |  49 +++-----
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |   4 +-
 .../Conversion/XeGPUToXeVM/prefetch_nd.mlir   |   2 +-
 5 files changed, 75 insertions(+), 90 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index bc360c1a6929b..9b4c620bf518d 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -186,9 +186,6 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    // Descriptor shape is expected to be 2D.
+    // Descriptor shape may be 1D or 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -199,8 +196,19 @@ class CreateNdDescToXeVMPattern
       }
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      // Cast index to i64.
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
     } else {
       baseAddr = adaptor.getSource();
+      if (baseAddr.getType() != i64Ty) {
+        // Pointer type may be i32. Cast to i64 if needed.
+        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
+      }
+    }
+    // A 1D tensor descriptor is just the base address.
+    if (rank == 1) {
+      rewriter.replaceOp(op, baseAddr);
+      return success();
     }
     // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
@@ -215,13 +223,6 @@ class CreateNdDescToXeVMPattern
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
-    if (sourceMemrefTy) {
-      // Cast index to i64.
-      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else if (baseAddr.getType() != i64Ty) {
-      // Pointer type may be i32. Cast to i64 if needed.
-      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    }
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -257,57 +258,57 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     auto mixedOffsets = op.getMixedOffsets();
     int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
     auto tdesc = adaptor.getTensorDesc();
     auto tdescTy = op.getTensorDescType();
+    auto tileRank = tdescTy.getRank();
+    if (opOffsetsSize != tileRank)
+      return rewriter.notifyMatchFailure(
+          op, "Expected offset rank to match descriptor rank.");
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     if (elemBitSize % 8 != 0)
       return rewriter.notifyMatchFailure(
           op, "Expected element type bit width to be multiple of 8.");
 
-    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
-    Value payLoadAsI64 =
-        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr = vector::ExtractOp::create(
-        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
-    Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
-    Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets are provided by the op.
-    // convert them to i32.
-    Value offsetW =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetW);
-    Value offsetH =
-        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
-    // Compute element byte size.
-    Value elemByteSize = arith::ConstantIntOp::create(
-        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
-    auto tileRank = tdescTy.getRank();
-    // Get tile width from the tensor descriptor type.
-    auto tileW = tdescTy.getDimSize(tileRank - 1);
     if (tileRank == 2) {
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+      Value payLoadAsI64 =
+          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+      Value basePtr =
+          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                    static_cast<int>(NdTdescOffset::BasePtr));
+      Value baseShapeW = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
+      Value baseShapeH = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      // Offsets are provided by the op.
+      // Convert them to i32.
+      Value offsetW =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      Value offsetH =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
       // Convert base pointer (i64) to LLVM pointer type.
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
       // Compute surface width in bytes.
-      Value elemByteSize = arith::ConstantIntOp::create(
-          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
       Value surfaceW =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
 
+      // Get tile width from the tensor descriptor type.
+      auto tileW = tdescTy.getDimSize(tileRank - 1);
       // Get tile height from the tensor descriptor type.
       auto tileH = tdescTy.getDimSize(0);
       // Get vblocks from the tensor descriptor type.
@@ -367,21 +368,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         }
       }
     } else {
-      // Get address from base address and offsets.
+      // 1D tensor descriptor.
+      // `tdesc` represents the base address as i64.
       // The offset is in elements; multiply by the element byte size.
-      // Compute linear offset.
-      // linearOffset = offsetH * baseShapeW + offsetW
-      Value offsetHInElems =
-          rewriter.createOrFold<arith::MulIOp>(loc, offsetH, baseShapeW);
-      Value linearOffset =
-          rewriter.createOrFold<arith::AddIOp>(loc, offsetHInElems, offsetW);
-      // Then scale by the element byte size to get the byte offset.
-      // byteOffset = linearOffset * elemByteSize
+      // Compute byte offset.
+      //   byteOffset = offset * elemByteSize
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI64Type(), offset);
+      // Compute element byte size.
+      Value elemByteSize = arith::ConstantIntOp::create(
+          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
       Value byteOffset =
-          rewriter.createOrFold<arith::MulIOp>(loc, linearOffset, elemByteSize);
+          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
-      // Final address = basePtr + byteOffset
+      // Final address = base address (tdesc) + byteOffset
       Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
-          loc, basePtr,
+          loc, tdesc,
           getValueOrCreateCastToIndexLike(rewriter, loc, rewriter.getI64Type(),
                                           byteOffset));
       // Convert the final address (i64) to LLVM pointer type.
@@ -992,7 +995,10 @@ struct ConvertXeGPUToXeVMPass
       return VectorType::get(sum, elemType);
     });
     typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
+      // Scattered descriptors are not supported in XeVM lowering.
       if (type.isScattered())
+        return {};
+      if (type.getRank() == 1)
         return IntegerType::get(&getContext(), 64);
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 09ef76c9d1740..109312218afae 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -29,13 +29,13 @@ gpu.module @create_nd_tdesc {
 
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
+        // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
         // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
         // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
         // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
-        // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
@@ -53,11 +53,11 @@ gpu.module @create_nd_tdesc {
         %BLOCK_DMODEL = arith.constant 16 : index
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
+        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
         // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
         // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
index 54804ee6cd033..7b4ad9ec2df03 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_1d.mlir
@@ -2,55 +2,34 @@
 
 gpu.module @load_store_check {
     // CHECK-LABEL: @load_store(
-    // CHECK-SAME: %[[SRC:.*]]: memref<8x64xf32, 1>, %[[DST:.*]]: memref<8x32xf32, 1>
-    gpu.func @load_store(%src: memref<8x64xf32, 1>, %dst: memref<8x32xf32, 1>) kernel {
+    // CHECK-SAME: %[[SRC:.*]]: memref<512xf32, 1>, %[[DST:.*]]: memref<256xf32, 1>
+    gpu.func @load_store(%src: memref<512xf32, 1>, %dst: memref<256xf32, 1>) kernel {
         // CHECK: %[[C512:.*]] = arith.constant 512 : i64
-        // CHECK: %[[C32:.*]] = arith.constant 32 : i32
         // CHECK: %[[C384:.*]] = arith.constant 384 : i64
-        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi64>
-        // CHECK: %[[C8:.*]] = arith.constant 8 : i32
-        // CHECK: %[[C64:.*]] = arith.constant 64 : i32
-        // CHECK: %[[C0:.*]] = arith.constant 0 : i32
 
-        // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<8x64xf32, 1> to memref<8x64xf32>
-        %srcce = memref.memory_space_cast %src : memref<8x64xf32, 1> to memref<8x64xf32>
-        // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<8x32xf32, 1> to memref<8x32xf32>
-        %dstte = memref.memory_space_cast %dst : memref<8x32xf32, 1> to memref<8x32xf32>
+        // CHECK: %[[SRCCE:.*]] = memref.memory_space_cast %[[SRC]] : memref<512xf32, 1> to memref<512xf32>
+        %srcce = memref.memory_space_cast %src : memref<512xf32, 1> to memref<512xf32>
+        // CHECK: %[[DSTTE:.*]] = memref.memory_space_cast %[[DST]] : memref<256xf32, 1> to memref<256xf32>
+        %dstte = memref.memory_space_cast %dst : memref<256xf32, 1> to memref<256xf32>
 
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<8x64xf32> -> index
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[SRCCE]] : memref<512xf32> -> index
         // CHECK: %[[INTPTR_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-        // CHECK: %[[VEC1:.*]] = vector.insert %[[INTPTR_I64]], %[[CST]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VEC2:.*]] = vector.bitcast %[[VEC1]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VEC3:.*]] = vector.insert %[[C64]], %[[VEC2]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VEC4:.*]] = vector.insert %[[C8]], %[[VEC3]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VEC5:.*]] = vector.insert %[[C0]], %[[VEC4]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VEC6:.*]] = vector.insert %[[C0]], %[[VEC5]] [5] : i32 into vector<8xi32>
-        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x64xf32> -> !xegpu.tensor_desc<32xf32>
-        // CHECK: %[[VEC7:.*]] = vector.bitcast %[[VEC6]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[EXTR:.*]] = vector.extract %[[VEC7]][0] : i64 from vector<4xi64>
-        // CHECK: %[[ADDR:.*]] = arith.addi %[[EXTR]], %[[C384]] : i64
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<512xf32> -> !xegpu.tensor_desc<32xf32>
+        // CHECK: %[[ADDR:.*]] = arith.addi %[[INTPTR_I64]], %[[C384]] : i64
         // CHECK: %[[PTR:.*]] = llvm.inttoptr %[[ADDR]] : i64 to !llvm.ptr<1>
         // CHECK: %[[LOAD:.*]] = xevm.blockload %[[PTR]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}>
         // CHECK-SAME: : (!llvm.ptr<1>) -> vector<2xi32>
-        %loaded = xegpu.load_nd %src_tdesc[1, 32] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        %loaded = xegpu.load_nd %src_tdesc[96] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<32xf32> -> vector<2xf32>
 
-        // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<8x32xf32> -> index
+        // CHECK: %[[INTPTR1:.*]] = memref.extract_aligned_pointer_as_index %[[DSTTE]] : memref<256xf32> -> index
         // CHECK: %[[INTPTR1_I64:.*]] = arith.index_castui %[[INTPTR1]] : index to i64
-        // CHECK: %[[VEC1_1:.*]] = vector.insert %[[INTPTR1_I64]], %[[CST]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VEC2_1:.*]] = vector.bitcast %[[VEC1_1]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VEC3_1:.*]] = vector.insert %[[C32]], %[[VEC2_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VEC4_1:.*]] = vector.insert %[[C8]], %[[VEC3_1]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VEC5_1:.*]] = vector.insert %[[C0]], %[[VEC4_1]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VEC6_1:.*]] = vector.insert %[[C0]], %[[VEC5_1]] [5] : i32 into vector<8xi32>
-        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x32xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
-        // CHECK: %[[VEC7_1:.*]] = vector.bitcast %[[VEC6_1]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[EXTR1:.*]] = vector.extract %[[VEC7_1]][0] : i64 from vector<4xi64>
-        // CHECK: %[[ADDR1:.*]] = arith.addi %[[EXTR1]], %[[C512]] : i64
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<256xf32> -> !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        // CHECK: %[[ADDR1:.*]] = arith.addi %[[INTPTR1_I64]], %[[C512]] : i64
         // CHECK: %[[PTR1:.*]] = llvm.inttoptr %[[ADDR1]] : i64 to !llvm.ptr<1>
         // CHECK: xevm.blockstore %[[PTR1]], %[[LOAD]] <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}>
         // CHECK-SAME: : (!llvm.ptr<1>, vector<2xi32>)
-        xegpu.store_nd %loaded, %dst_tdesc[4, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+        xegpu.store_nd %loaded, %dst_tdesc[128] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<2xf32>, !xegpu.tensor_desc<32xf32, #xegpu.block_tdesc_attr<memory_space = global>>
         gpu.return
     }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index 4c6bbf25b4728..95774ca67c4f2 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -16,6 +16,7 @@ gpu.module @load_store_check {
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
 
+        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
         //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @load_store_check {
         //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
         //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
         //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
         //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
         //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
@@ -52,6 +52,7 @@ gpu.module @load_store_check {
         // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
+        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
         //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
@@ -61,7 +62,6 @@ gpu.module @load_store_check {
         //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
         //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
         //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
         //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
         //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 873478aed57e3..ac13ae9013593 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -16,6 +16,7 @@ gpu.module @fence_check {
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
             #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 
+        //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
         //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
@@ -25,7 +26,6 @@ gpu.module @fence_check {
         //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
         //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
         //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
         //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
         //CHECK-SAME:   %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]


