[Mlir-commits] [mlir] 447af32 - [MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. (#170384)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Dec 8 08:15:48 PST 2025


Author: Sang Ik Lee
Date: 2025-12-08T08:15:44-08:00
New Revision: 447af32fbb8c9522bc01948caadd1ae562c01373

URL: https://github.com/llvm/llvm-project/commit/447af32fbb8c9522bc01948caadd1ae562c01373
DIFF: https://github.com/llvm/llvm-project/commit/447af32fbb8c9522bc01948caadd1ae562c01373.diff

LOG: [MLIR][XeGPU][XeVM] create_nd_tdesc: use correct pitch from strides. (#170384)

The base memory pitch should be derived from the base stride, not from the base width.
This change removes the always-zero offset fields from the tensor descriptor payload and adds a pitch field in their place (payload slot 4).

Added: 
    

Modified: 
    mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
    mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
    mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
    mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 7f1ec17ce0ae8..9c99a24bea8cd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -50,11 +50,10 @@ static constexpr int32_t executionSize{16};
 
 // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
 enum class NdTdescOffset : uint32_t {
-  BasePtr = 0,       // Base pointer (i64)
-  BaseShapeW = 2,    // Base shape width (i32)
-  BaseShapeH = 3,    // Base shape height (i32)
-  TensorOffsetW = 4, // Tensor offset W (i32)
-  TensorOffsetH = 5  // Tensor offset H (i32)
+  BasePtr = 0,    // Base pointer (i64)
+  BaseShapeW = 2, // Base shape width (i32)
+  BaseShapeH = 3, // Base shape height (i32)
+  BasePitch = 4,  // Base pitch (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -179,11 +178,10 @@ class CreateNdDescToXeVMPattern
     Value baseAddr;
     Value baseShapeW;
     Value baseShapeH;
-    Value offsetW;
-    Value offsetH;
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     auto sourceTy = source.getType();
@@ -216,12 +214,11 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets are not supported (0 is used).
-    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
+    // Get pitch value from op fold results.
+    Value basePitch = createOffset(mixedStrides, 0);
     // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
@@ -235,12 +232,9 @@ class CreateNdDescToXeVMPattern
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                  static_cast<int>(NdTdescOffset::BaseShapeH));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetW, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetW));
-    payload = vector::InsertOp::create(
-        rewriter, loc, offsetH, payload,
-        static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload =
+        vector::InsertOp::create(rewriter, loc, basePitch, payload,
+                                 static_cast<int>(NdTdescOffset::BasePitch));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -289,6 +283,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
           rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
       Value baseShapeH = vector::ExtractOp::create(
           rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+      Value basePitch = vector::ExtractOp::create(
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
       // Offsets are provided by the op.
       // convert them to i32.
       Value offsetW =
@@ -303,8 +299,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
       Value basePtrLLVM =
           LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
       // Compute width in bytes.
-      Value surfaceW =
+      Value baseWidthByte =
           arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+      // Compute pitch in bytes.
+      Value basePitchByte =
+          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
 
       // Get tile width from the tensor descriptor type.
       auto tileW = tdescTy.getDimSize(tileRank - 1);
@@ -331,8 +330,8 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         auto storeCacheControl =
             translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         xevm::BlockStore2dOp::create(
-            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-            offsetH, elemBitSize, tileW, tileH, src,
+            rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+            basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH, src,
             xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
         rewriter.eraseOp(op);
       } else {
@@ -340,9 +339,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
             translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
         if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
           xevm::BlockPrefetch2dOp::create(
-              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
-              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+              rewriter, loc, basePtrLLVM, baseWidthByte, baseShapeH,
+              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           rewriter.eraseOp(op);
         } else {
           VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
@@ -355,9 +354,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
                              : rewriter.getIntegerType(elemBitSize));
 
           Value resultFlatVec = xevm::BlockLoad2dOp::create(
-              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
-              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
-              transpose, vnni,
+              rewriter, loc, loadedTy, basePtrLLVM, baseWidthByte, baseShapeH,
+              basePitchByte, offsetW, offsetH, elemBitSize, tileW, tileH,
+              vblocks, transpose, vnni,
               xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
           resultFlatVec = vector::BitCastOp::create(
               rewriter, loc,

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 8b87b791c9fd3..9a1e2cb3c7de0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -8,21 +8,19 @@ gpu.module @create_nd_tdesc {
   gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index, %dyn: memref<?x?xf16>) kernel {
         // CHECK: %[[INTPTR_5:.*]] = memref.extract_aligned_pointer_as_index %[[DYN]] : memref<?x?xf16> -> index
-        // CHECK: %[[VAR23:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
+        // CHECK: %[[DYN_ADDR:.*]] = arith.index_castui %[[INTPTR_5]] : index to i64
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[OFFSET_W:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H:.*]] = arith.constant 0 : i32
         // CHECK: %[[SHAPE_W:.*]] = arith.index_cast %[[ARG3]] : index to i32
         // CHECK: %[[SHAPE_H:.*]] = arith.index_cast %[[ARG2]] : index to i32
+        // CHECK: %[[PITCH:.*]] = arith.index_cast %[[ARG4]] : index to i32
         // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR7:.*]] = vector.insert %[[BASE_ADDR]], %[[VAR6]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR9:.*]] = vector.insert %[[SHAPE_W]], %[[VAR8]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR10:.*]] = vector.insert %[[SHAPE_H]], %[[VAR9]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR11:.*]] = vector.insert %[[OFFSET_W]], %[[VAR10]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR12:.*]] = vector.insert %[[OFFSET_H]], %[[VAR11]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR11:.*]] = vector.insert %[[PITCH]], %[[VAR10]] [4] : i32 into vector<8xi32>
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
@@ -32,19 +30,18 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
-        // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
         // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
         // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
         // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+        // CHECK: %[[C32_I64_2:.*]] = arith.constant 32 : i64
+        // CHECK: %[[PITCH2:.*]] = arith.trunci %[[C32_I64_2]] : i64 to i32
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR19:.*]] = vector.insert %[[PITCH2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         // CHECK: %[[C1:.*]] = arith.constant 1 : index
@@ -53,18 +50,16 @@ gpu.module @create_nd_tdesc {
         %size_x = arith.constant 64 : index
         // CHECK: %[[C16:.*]] = arith.constant 16 : index
         %BLOCK_DMODEL = arith.constant 16 : index
-        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[C0_I32_6:.*]] = arith.constant 0 : i32
-        // CHECK: %[[C0_I32_7:.*]] = arith.constant 0 : i32
-        // CHECK: %[[VAR21:.*]] = arith.index_cast %[[C16]] : index to i32
-        // CHECK: %[[VAR22:.*]] = arith.index_cast %[[C64]] : index to i32
-        // CHECK: %[[VAR24:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR25:.*]] = vector.insert %[[VAR23]], %[[VAR24]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[VAR25]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR27:.*]] = vector.insert %[[VAR21]], %[[VAR26]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR28:.*]] = vector.insert %[[VAR22]], %[[VAR27]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR29:.*]] = vector.insert %[[C0_I32_6]], %[[VAR28]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR30:.*]] = vector.insert %[[C0_I32_7]], %[[VAR29]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[CST_3:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[SHAPE_W3:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[SHAPE_H3:.*]] = arith.index_cast %[[C64]] : index to i32
+        // CHECK: %[[PITCH3:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[VAR25:.*]] = vector.bitcast %[[CST_3]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR26:.*]] = vector.insert %[[DYN_ADDR]], %[[VAR25]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR27:.*]] = vector.bitcast %[[VAR26]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR28:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR27]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR28]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR30:.*]] = vector.insert %[[PITCH3]], %[[VAR29]] [4] : i32 into vector<8xi32>
         %dyn_tdesc  = xegpu.create_nd_tdesc %dyn, shape: [%size_x, %BLOCK_DMODEL], strides: [%BLOCK_DMODEL, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<16x16xf16>
         gpu.return
     }

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index afeae8be24b72..4c73c9c238b6e 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -1,78 +1,32 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm -canonicalize %s | FileCheck %s
 
 gpu.module @load_store_check {
     // CHECK-LABEL: gpu.func @load_store(
-    // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: memref<8x16xf32, 1>) kernel {
     gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        // CHECK: %[[W_P_BYTES:.*]] = arith.constant 64 : i32
+        // CHECK: %[[ZERO:.*]] = arith.constant 0 : i32
+        // CHECK: %[[H:.*]] = arith.constant 8 : i32
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
         %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
 
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-        // CHECK: %[[MEMSPACECAST_0:.*]] = memref.memory_space_cast %[[ARG1]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        // CHECK: %[[INTPTR_1:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST_0]] : memref<8x16xf32> -> index
-        // CHECK: %[[ST_PTR_AS_I64:.*]] = arith.index_castui %[[INTPTR_1]] : index to i64
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
 
-
-        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
-        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
-        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
-        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]]
         //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
         //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
         %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
             : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
-        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
 
         %tid_x = gpu.thread_id x
         %tid_x_i32 = arith.index_cast %tid_x : index to i32
         %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
-        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
         %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
 
-        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[DESC_0:.*]] = vector.insert %[[ST_PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
         %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
 
-        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
-        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
-        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
-        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
-        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
-        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
-        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
-        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
-        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
-        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK: xevm.blockstore2d %{{.*}}, %[[W_P_BYTES]], %[[H]], %[[W_P_BYTES]], %[[ZERO]], %[[ZERO]], %{{.*}} <{
+        //CHECK-SAME:   cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
         xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
             : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>

diff  --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index e4b303087ea9b..43df721fb77a0 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -3,27 +3,16 @@
 gpu.module @prefetch_nd_check {
     // CHECK-LABEL: gpu.func @prefetch_nd
     gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
-        // CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.constant 64 : i32
-        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = arith.constant dense<0> : vector<4xi64>
-        // CHECK: %[[PREF_BASE_H:.*]] = arith.constant 8 : i32
-        // CHECK: %[[PREF_BASE_W:.*]] = arith.constant 16 : i32
+        // CHECK: %[[BASE_WIDTH_PITCH_BYTES:.*]] = arith.constant 64 : i32
         // CHECK: %[[OFFSET_ZERO:.*]] = arith.constant 0 : i32
+        // CHECK: %[[BASE_H:.*]] = arith.constant 8 : i32
         %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
-        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
-        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[LD_DESC_2:.*]] = vector.insert %[[PREF_BASE_W]], %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_3:.*]] = vector.insert %[[PREF_BASE_H]], %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC_4:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[LD_DESC:.*]] = vector.insert %[[OFFSET_ZERO]], %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
             #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
 
-        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
-        //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
-        //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
-        //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
-        //CHECK-SAME:   %[[PREF_BASE_ROW_IN_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
+        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %{{.*}} : i64 to !llvm.ptr<1>
+        //CHECK: xevm.blockprefetch2d %[[LLVMPTR]], %[[BASE_WIDTH_PITCH_BYTES]], %[[BASE_H]],
+        //CHECK-SAME:   %[[BASE_WIDTH_PITCH_BYTES]], %[[OFFSET_ZERO]], %[[OFFSET_ZERO]]
         //CHECK-SAME:   <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
         //CHECK-SAME:     tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
         //CHECK-SAME:   : (!llvm.ptr<1>, i32, i32, i32, i32, i32)


        


More information about the Mlir-commits mailing list