[Mlir-commits] [mlir] [MLIR][Conversion] Convert XeGPU to XeVM pass: Remove lowering support for tensor descriptor with offsets. (PR #157550)

Sang Ik Lee llvmlistbot at llvm.org
Mon Sep 8 13:49:41 PDT 2025


https://github.com/silee2 created https://github.com/llvm/llvm-project/pull/157550

Also update the load/store/prefetch test cases to use direct offsets.
Tensor descriptors with offsets are being deprecated.

>From 68d16070e49d8b69ff4d359ff13f56a3d67144ec Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 8 Sep 2025 20:45:26 +0000
Subject: [PATCH] Remove lowering support for tensor descriptor with offsets.
 Update load/store/prefetch test cases to use direct offsets.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 170 +++---------
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |  32 ---
 .../XeGPUToXeVM/loadstoreprefetch.mlir        | 244 +++---------------
 .../XeGPUToXeVM/materializecast.mlir          |  38 +--
 .../Conversion/XeGPUToXeVM/update_offset.mlir |  25 --
 5 files changed, 85 insertions(+), 424 deletions(-)
 delete mode 100644 mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index a7f2dc2d6a43e..7d756620ecd16 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -154,6 +154,9 @@ class CreateNdDescToXeVMPattern
   matchAndRewrite(xegpu::CreateNdDescOp op,
                   xegpu::CreateNdDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 0)
+      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
@@ -177,7 +180,6 @@ class CreateNdDescToXeVMPattern
 
     // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
-    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
     if (rank != 2)
@@ -202,17 +204,9 @@ class CreateNdDescToXeVMPattern
       val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    // Offsets can be either 2D or not provided (0 is used).
-    if (mixedOffsets.size() == 2) {
-      offsetW = createOffset(mixedOffsets, 1);
-      offsetH = createOffset(mixedOffsets, 0);
-    } else if (mixedOffsets.size() == 0) {
-      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
-    } else {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    }
+    // Offsets are not supported now (0 is used).
+    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
     // Get shape values from op fold results.
     baseShapeW = createOffset(mixedSizes, 1);
     baseShapeH = createOffset(mixedSizes, 0);
@@ -247,39 +241,6 @@ class CreateNdDescToXeVMPattern
   }
 };
 
-class UpdateNdOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
-                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto loc = op.getLoc();
-    auto mixedOffsets = op.getMixedOffsets();
-    // Only 2D offsets are supported for now.
-    if (mixedOffsets.size() != 2)
-      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto payload = adaptor.getTensorDesc();
-    // Utility for updating payload offset values from op fold result.
-    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
-      Value offset =
-          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
-      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                               rewriter.getI32Type(), offset);
-      Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
-      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
-                                      payloadPos);
-    };
-    // Update offsets in the payload.
-    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, payload);
-    return success();
-  }
-};
-
 template <
     typename OpType,
     typename = std::enable_if_t<llvm::is_one_of<
@@ -289,6 +250,10 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
 
@@ -311,32 +276,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
         rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
-    // Offsets provided in two ways:
-    // 1. Offsets are extracted from the tensor descriptor.
-    // 2. (Mixed) offsets which are provided by the op.
-    Value offsetW;
-    Value offsetH;
-    auto mixedOffsets = op.getMixedOffsets();
-    int64_t opOffsetsSize = mixedOffsets.size();
-    if (opOffsetsSize != 0 && opOffsetsSize != 2)
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected 2D offsets or no offsets.");
-    if (opOffsetsSize) {
-      // If mixed offsets are provided by the op convert them to i32.
-      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
-      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetW);
-      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
-      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
-                                                rewriter.getI32Type(), offsetH);
-    } else {
-      // If offsets are not available, we need to extract them from the tensor
-      // descriptor.
-      offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
-      offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    }
+    // Offsets are provided by the op.
+    // Convert them to i32.
+    Value offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                             rewriter.getI32Type(), offsetW);
+    Value offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                              rewriter.getI32Type(), offsetH);
     // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
@@ -422,54 +369,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
   return newAddr;
 }
 
-class CreateDescToXeVMPattern
-    : public OpConversionPattern<xegpu::CreateDescOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Offsets are provided as scalar i64 by type converter.
-    auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
-    // But type converter will convert them to integer types.
-    Value addr = adaptor.getSource();
-    // ui32 or i32 are passed as i32 so they need to be casted to i64.
-    if (addr.getType() != rewriter.getI64Type())
-      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
-    rewriter.replaceOp(op, laneAddr);
-    return success();
-  }
-};
-
-class UpdateOffsetToXeVMPattern
-    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
-  using OpConversionPattern::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::UpdateOffsetOp op,
-                  xegpu::UpdateOffsetOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto eTy = op.getTensorDescType().getElementType();
-    auto eBw = eTy.getIntOrFloatBitWidth();
-    if (eBw % 8 != 0)
-      return rewriter.notifyMatchFailure(
-          op, "Expected element type bit width to be multiple of 8.");
-    auto loc = op.getLoc();
-    // Scatter descriptor is provided as scalar i64 by type converter.
-    // Offsets are provided as scalar i64 by type converter.
-    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
-                                adaptor.getOffsets(), eBw / 8);
-    rewriter.replaceOp(op, newOffset);
-    return success();
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
@@ -478,6 +377,9 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   LogicalResult
   matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Value offsets = adaptor.getOffsets();
+    if (!offsets)
+      return rewriter.notifyMatchFailure(op, "Expected offsets to be provided.");
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
@@ -527,21 +429,18 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                           basePtrI64);
     }
-    Value offsets = adaptor.getOffsets();
     Value mask = adaptor.getMask();
-    if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())) {
-        // Offset needs be scalar. Single element vector is converted to scalar
-        // by type converter.
-        return rewriter.notifyMatchFailure(op,
-                                           "Expected offsets to be a scalar.");
-      } else {
-        // If offsets are provided, we add them to the base pointer.
-        // Offsets are in number of elements, we need to multiply by
-        // element byte size.
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
-      }
+    if (dyn_cast<VectorType>(offsets.getType())) {
+      // Offset needs to be scalar. Single element vector is converted to scalar
+      // by type converter.
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected offsets to be a scalar.");
+    } else {
+      // Offsets are required (checked above); add them to the base pointer.
+      // Offsets are in number of elements, we need to multiply by
+      // element byte size.
+      basePtrI64 =
+          addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -1011,13 +910,12 @@ struct ConvertXeGPUToXeVMPass
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
     const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
-  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+  patterns.add<CreateNdDescToXeVMPattern,
                LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
       typeConverter, patterns.getContext());
-  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
-               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index ed664a739d134..d6e36fa73bf04 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -43,38 +43,6 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
         // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
         %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-        // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
-        // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
-        // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-        // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
-        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
-        // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
-        // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
-        // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
-        // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
-        // CHECK: %[[VAR28:.*]] = vector.bitcast %[[VAR27]] : vector<4xi64> to vector<8xi32>
-        // CHECK: %[[VAR29:.*]] = vector.insert %[[SHAPE_W3]], %[[VAR28]] [2] : i32 into vector<8xi32>
-        // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
-        // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
-
-        // CHECK: %[[C8:.*]] = arith.constant 8 : index
-        %c8 = arith.constant 8 : index
-        // CHECK: %[[C16:.*]] = arith.constant 16 : index
-        %c16 = arith.constant 16 : index
-        // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
-        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
-        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
-        // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
-        // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
-        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
-        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
-        // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
-        %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
         gpu.return
     }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 0f67dc290689b..0b150e9d58153 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -1,239 +1,73 @@
 // RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
 
 gpu.module @test {
-// CHECK-LABEL: @load_gather_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
-  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
-  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
-  // CHECK:      %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<2xf32>
-  // CHECK:      scf.yield %[[VAR9]] : vector<2xf32>
-  // CHECK:    } else {
-  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
-  // CHECK:      scf.yield %[[CST_1]] : vector<2xf32>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
-  %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
-  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
-  // CHECK:      %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<1xf32>
-  // CHECK:      %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
-  // CHECK:      scf.yield %[[VAR9]] : f32
-  // CHECK:    } else {
-  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
-  // CHECK:      %[[VAR8:.*]] = vector.extract %[[CST_1:.*]][0] : f32 from vector<1xf32>
-  // CHECK:      scf.yield %[[VAR8]] : f32
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @load_gather_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
   // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
   // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
   // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
-  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
-  %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) {
-  // CHECK:      %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<8xf16>
-  // CHECK:      scf.yield %[[VAR8]] : vector<8xf16>
-  // CHECK:    } else {
-  // CHECK:      %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
-  // CHECK:      scf.yield %[[CST_0]] : vector<8xf16>
-  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
-  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
-  %2 = arith.constant dense<2.9>: vector<2xf32>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
-  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
-  // CHECK:    scf.if %[[VAR4]] {
-  // CHECK:      llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
-  // CHECK-SAME:     : vector<2xf32>, !llvm.ptr<1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
-  %2 = arith.constant dense<2.9>: vector<2xf16>
-  // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
-  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
-  %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
-  // CHECK: scf.if %[[VAR2]] {
-  // CHECK:      llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
-  // CHECK-SAME:     : vector<2xf16>, !llvm.ptr<1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+  // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+  // CHECK: %[[VAR4:.*]] = arith.addi %[[ARG0]], %[[VAR3]] : i64
+  // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+  // CHECK: %[[VAR6:.*]] = scf.if %[[VAR2]] -> (f16) {
+  // CHECK:   %[[VAR7:.*]] = llvm.load %[[VAR5]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>} : !llvm.ptr<1> -> vector<1xf16>
+  // CHECK:   %[[VAR8:.*]] = vector.extract %[[VAR7]][0] : f16 from vector<1xf16>
+  // CHECK:   scf.yield %[[VAR8]] : f16
+  // CHECK: } else {
+  // CHECK:   %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf16>
+  // CHECK:   %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f16 from vector<1xf16>
+  // CHECK:   scf.yield %[[VAR7]] : f16
+  // CHECK: }
+  %3 = xegpu.load %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : i64, vector<1xindex>, vector<1xi1> -> vector<1xf16>
   gpu.return
 }
 }
 // -----
 
 gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_value_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
-gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+// CHECK-LABEL: @store_scatter_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @store_scatter_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
   // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
   // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
-  // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
+  // CHECK: %[[VAR3:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
   %2 = arith.constant dense<2.9>: vector<1xf32>
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
-  %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[ARG0]], %[[VAR4]] : i64
   // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
   // CHECK: scf.if %[[VAR2]] {
-  // CHECK:      llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
-  // CHECK-SAME:     : f32, !llvm.ptr<1>
-  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
+  // CHECK:   llvm.store %[[VAR3]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>} : f32, !llvm.ptr<1>
+  // CHECK: }
+  xegpu.store %2, %src[%offset], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<1xf32>, i64, vector<1xindex>, vector<1xi1>
   gpu.return
 }
 }
 // -----
 
 gpu.module @test {
-// CHECK-LABEL: @prefetch_ui64_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: ui64
-gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
-  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64
-  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
-  // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  gpu.return
-}
-}
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @prefetch_memref_src_constant_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
-gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+// CHECK-LABEL: @prefetch_i64_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: i64, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @prefetch_i64_src_value_offset(%src: i64, %offset: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
   // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-  %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
-  %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
-  // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR2:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR3:.*]] = arith.addi %[[ARG0]], %[[VAR2]] : i64
+  // CHECK: %[[VAR4:.*]] = llvm.inttoptr %[[VAR3]] : i64 to !llvm.ptr<1>
+  // CHECK: xevm.prefetch %[[VAR4]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  xegpu.prefetch %src[%offset] <{offset_align_byte=4, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : i64, vector<1xindex>
   gpu.return
 }
 }
@@ -250,12 +84,10 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
   // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
-  %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
   // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  xegpu.prefetch %src[%offset] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : memref<256xf32>, vector<1xindex>
   gpu.return
 }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index b28a8c2ccf843..2a2b99f57cabd 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -9,9 +9,9 @@ gpu.module @materializecast {
   gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
     // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
     // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    %offset = arith.constant dense<0> : vector<1xindex>
-    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
-        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    %offset = arith.constant 0 : index
+    %mask = arith.constant 1 : i1
+    %val = xegpu.load %src[%offset], %mask : memref<128xf32>, index, i1 -> f32
     gpu.return
   }
 }
@@ -23,9 +23,9 @@ gpu.module @materializecast {
   gpu.func @materialize_ui64(%src: ui64) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
     // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-    %offset = arith.constant dense<0> : vector<1xindex>
-    %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
-        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    %offset = arith.constant 0 : index
+    %mask = arith.constant 1 : i1
+    %val = xegpu.load %src[%offset], %mask : ui64, index, i1 -> vector<1xf32>
     gpu.return
   }
 }
@@ -37,9 +37,9 @@ gpu.module @materializecast {
   gpu.func @materialize_ui32(%src: ui32) kernel {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index
     // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32
-    %offset = arith.constant dense<0> : vector<1xindex>
-    %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
-        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    %offset = arith.constant 0 : index
+    %mask = arith.constant 1 : i1
+    %val = xegpu.load %src[%offset], %mask : ui32, index, i1 -> vector<1xf32>
     gpu.return
   }
 }
@@ -52,24 +52,12 @@ gpu.module @materializecast {
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
     // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64
+    // CHECK: %[[CST1:.*]] = arith.constant dense<true> : vector<1xi1>
+    // CHECK: %[[VAR3:.*]] = vector.extract %[[CST1]][0] : i1 from vector<1xi1>
     %offset = arith.constant dense<0> : vector<1xindex>
-    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
-        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    %mask = arith.constant dense<1> : vector<1xi1>
+    %val = xegpu.load %src[%offset], %mask : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1xf32>
     gpu.return
   }
 }
 
-// -----
-gpu.module @materializecast {
-  // CHECK-LABEL: gpu.func @materialize_single_elem_vector
-  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
-  gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
-    // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
-    // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
-    %mask = arith.constant dense<1>: vector<1xi1>
-    %offset = arith.constant dense<0> : vector<1xindex>
-    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
-    gpu.return
-  }
-}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
deleted file mode 100644
index 6e59414c62582..0000000000000
--- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
+++ /dev/null
@@ -1,25 +0,0 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
-
-gpu.module @update_offset {
-  // CHECK-LABEL: gpu.func @update_offset
-  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
-  gpu.func @update_offset(%src: memref<128xf32>) kernel {
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
-    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
-    %offset = arith.constant dense<0> : vector<1xindex>
-    // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
-    // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-    // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
-    // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
-    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
-        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
-    // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
-    // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
-    // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
-    %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
-        , vector<1xindex>
-    gpu.return
-  }
-}



More information about the Mlir-commits mailing list