[Mlir-commits] [mlir] [MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. (PR #154556)

Sang Ik Lee llvmlistbot at llvm.org
Mon Aug 25 17:50:08 PDT 2025


https://github.com/silee2 updated https://github.com/llvm/llvm-project/pull/154556

>From b860103f876016074f35dd6f6401c6b58d31f6ef Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 14 Jul 2025 18:54:41 +0000
Subject: [PATCH 01/10] Add XeGPUToXeVM conversion pass and tests.

---
 mlir/include/mlir/Conversion/Passes.h         |   1 +
 mlir/include/mlir/Conversion/Passes.td        |  12 +
 .../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h |  27 +
 mlir/lib/Conversion/CMakeLists.txt            |   1 +
 .../lib/Conversion/XeGPUToXeVM/CMakeLists.txt |  25 +
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 932 ++++++++++++++++++
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |  48 +
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |  17 +
 mlir/test/Conversion/XeGPUToXeVM/fence.mlir   |  15 +
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |  71 ++
 .../XeGPUToXeVM/loadstoreprefetch.mlir        | 357 +++++++
 .../Conversion/XeGPUToXeVM/prefetch_nd.mlir   |  40 +
 .../Conversion/XeGPUToXeVM/update_offset.mlir |  25 +
 13 files changed, 1571 insertions(+)
 create mode 100644 mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
 create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
 create mode 100644 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/fence.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir

diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 91b2ecf8922a3..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -82,6 +82,7 @@
 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
 #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
 #include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
 #include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2058aba7f9e37..323af3e97e2d4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
   let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// XeGPUToXeVM
+//===----------------------------------------------------------------------===//
+
+// Lowers XeGPU dialect operations to the XeVM dialect. The dependent dialects
+// are loaded before the pass runs because the lowering may create their ops.
+def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
+  let summary = "Convert XeGPU to XeVM dialect";
+  let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect",
+                           "memref::MemRefDialect", "arith::ArithDialect",
+                           "LLVM::LLVMDialect", "index::IndexDialect",
+                           "gpu::GPUDialect", "scf::SCFDialect"];
+}
+
 #endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
new file mode 100644
index 0000000000000..fb23d24b0161b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -0,0 +1,27 @@
+//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect -----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+// Pulls in the TableGen-generated declarations for ConvertXeGPUToXeVMPass
+// (defined in Passes.td).
+#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Adds the patterns that lower XeGPU operations (to XeVM and supporting
+/// dialects) to `patterns`. `typeConverter` is presumably used by those
+/// patterns for operand/result type conversion; see the implementation.
+void populateXeGPUToXeVMConversionPatterns(
+    mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 171f7169fd41d..134fe8e14ca38 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF)
 add_subdirectory(VectorToSPIRV)
 add_subdirectory(VectorToXeGPU)
+add_subdirectory(XeGPUToXeVM)
 add_subdirectory(XeVMToLLVM)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..ed54b0bb5ee81
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+  XeGPUToXeVM.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  # The patterns create memref and scf ops and include SCF transform patterns,
+  # so those libraries must be linked explicitly (needed for shared builds).
+  LINK_LIBS PUBLIC
+  MLIRFuncDialect
+  MLIRGPUDialect
+  MLIRLLVMCommonConversion
+  MLIRLLVMDialect
+  MLIRXeVMDialect
+  MLIRVectorDialect
+  MLIRArithDialect
+  MLIRIndexDialect
+  MLIRMemRefDialect
+  MLIRSCFDialect
+  MLIRSCFTransforms
+  MLIRXeGPUDialect
+  MLIRPass
+  MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000000000..380409afbc62e
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,932 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Slot indices into the nd tensor-descriptor payload, an 8 x i32 vector.
+/// BasePtr is an index into the payload's 4 x i64 bitcast view, so the i64
+/// base address occupies i32 slots 0-1. The other entries index the i32 view
+/// directly. i32 slots 6-7 have no entry in this enum.
+enum class NdDescI32Layout : uint32_t {
+  BasePtr = 0,
+  BaseShapeW = 2,
+  BaseShapeH = 3,
+  TensorOffsetW = 4,
+  TensorOffsetH = 5
+};
+
+/// Translates an XeGPU memory space into the numeric XeVM address space used
+/// when building LLVM pointer types.
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+  if (xeGpuMemspace == xegpu::MemorySpace::Global)
+    return static_cast<int32_t>(xevm::AddrSpace::GLOBAL);
+  if (xeGpuMemspace == xegpu::MemorySpace::SLM)
+    return static_cast<int32_t>(xevm::AddrSpace::SHARED);
+  llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+/// Checks whether the non-empty sequence `denseAttr` is an arithmetic
+/// progression. Returns {isLinear, slope, intercept}. The intercept is the
+/// first element. A single element is linear with slope 0. For non-linear
+/// input, slope and intercept are both 0.
+template <typename T>
+std::tuple<bool, int32_t, int32_t>
+checkAllLinear(const SmallVector<T> &denseAttr) {
+  assert(!denseAttr.empty());
+  const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
+  if (denseAttr.size() < 2)
+    return {true, 0, intercept};
+  const T slope{denseAttr[1] - denseAttr[0]};
+  // The first difference defines the slope, so verification starts at 2.
+  for (size_t i = 2; i < denseAttr.size(); ++i)
+    if (denseAttr[i] - denseAttr[i - 1] != slope)
+      return {false, 0, 0};
+  return {true, static_cast<int32_t>(slope), intercept};
+}
+
+/// Returns a 1-D vector type with element type `toElemType` whose total bit
+/// size equals that of `currentVecType`. Both element types must be integer
+/// or float types.
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+  auto elemType = currentVecType.getElementType();
+  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+  const int64_t totalBits = currentVecType.getNumElements() * currentBitWidth;
+  // A remainder would silently drop bits and produce a type that cannot be
+  // the target of a bitcast from `currentVecType`.
+  assert(totalBits % newBitWidth == 0 &&
+         "vector bit size must be a multiple of the new element bit width");
+  const int64_t size = totalBits / newBitWidth;
+  return VectorType::get(size, toElemType);
+}
+
+/// Maps XeGPU L1/L3 load cache hints onto an XeVM load cache control. A
+/// missing hint is treated as UNCACHED. READ_INVALIDATE on L1 maps to
+/// INVALIDATE_READ regardless of L3. Every other supported combination maps
+/// to an L2-uncached control. Unsupported combinations are unreachable.
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                            std::optional<xegpu::CachePolicy> L3hint) {
+  using xegpu::CachePolicy;
+  using xevm::LoadCacheControl;
+  const CachePolicy l1 = L1hint.value_or(CachePolicy::UNCACHED);
+  const CachePolicy l3 = L3hint.value_or(CachePolicy::UNCACHED);
+  if (l1 == CachePolicy::READ_INVALIDATE)
+    return LoadCacheControl::INVALIDATE_READ;
+  if (l3 != CachePolicy::CACHED && l3 != CachePolicy::UNCACHED)
+    llvm_unreachable("Unsupported cache control.");
+  const bool l3Cached = l3 == CachePolicy::CACHED;
+  switch (l1) {
+  case CachePolicy::CACHED:
+    return l3Cached ? LoadCacheControl::L1C_L2UC_L3C
+                    : LoadCacheControl::L1C_L2UC_L3UC;
+  case CachePolicy::UNCACHED:
+    return l3Cached ? LoadCacheControl::L1UC_L2UC_L3C
+                    : LoadCacheControl::L1UC_L2UC_L3UC;
+  case CachePolicy::STREAMING:
+    return l3Cached ? LoadCacheControl::L1S_L2UC_L3C
+                    : LoadCacheControl::L1S_L2UC_L3UC;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+/// Maps XeGPU L1/L3 store cache hints onto an XeVM store cache control. A
+/// missing hint is treated as UNCACHED. L3 accepts UNCACHED or WRITE_BACK.
+/// L1 accepts UNCACHED, STREAMING, WRITE_BACK or WRITE_THROUGH. Every result
+/// is L2-uncached, and unsupported combinations are unreachable.
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+                             std::optional<xegpu::CachePolicy> L3hint) {
+  using xegpu::CachePolicy;
+  using xevm::StoreCacheControl;
+  const CachePolicy l1 = L1hint.value_or(CachePolicy::UNCACHED);
+  const CachePolicy l3 = L3hint.value_or(CachePolicy::UNCACHED);
+  if (l3 != CachePolicy::UNCACHED && l3 != CachePolicy::WRITE_BACK)
+    llvm_unreachable("Unsupported cache control.");
+  const bool l3WriteBack = l3 == CachePolicy::WRITE_BACK;
+  switch (l1) {
+  case CachePolicy::UNCACHED:
+    return l3WriteBack ? StoreCacheControl::L1UC_L2UC_L3WB
+                       : StoreCacheControl::L1UC_L2UC_L3UC;
+  case CachePolicy::STREAMING:
+    return l3WriteBack ? StoreCacheControl::L1S_L2UC_L3WB
+                       : StoreCacheControl::L1S_L2UC_L3UC;
+  case CachePolicy::WRITE_BACK:
+    return l3WriteBack ? StoreCacheControl::L1WB_L2UC_L3WB
+                       : StoreCacheControl::L1WB_L2UC_L3UC;
+  case CachePolicy::WRITE_THROUGH:
+    return l3WriteBack ? StoreCacheControl::L1WT_L2UC_L3WB
+                       : StoreCacheControl::L1WT_L2UC_L3UC;
+  default:
+    llvm_unreachable("Unsupported cache control.");
+  }
+}
+
+/// Lowers xegpu.create_nd_tdesc to an 8 x i32 payload vector. The payload
+/// holds the i64 base address, the 2D base shape (W, H) and the 2D tensor
+/// offsets (W, H). Slot positions are given by NdDescI32Layout.
+class CreateNdDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateNdDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateNdDescOp op,
+                  xegpu::CreateNdDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto source = op.getSource();
+    Type payloadElemTy = rewriter.getI32Type();
+    Type i64Ty = rewriter.getI64Type();
+    VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    Value payload = arith::ConstantOp::create(
+        rewriter, loc,
+        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+    Value baseAddr;
+    Value baseShapeW;
+    Value baseShapeH;
+    Value offsetW;
+    Value offsetH;
+
+    bool sourceIsMemref = false;
+    auto sourceTy = source.getType();
+    int64_t rank;
+    if (isa<MemRefType>(sourceTy)) {
+      sourceIsMemref = true;
+      baseAddr =
+          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      if (!sourceMemrefTy.hasStaticShape()) {
+        op.emitError() << "Expected static memref shape.";
+        return failure();
+      }
+      rank = sourceMemrefTy.getRank();
+      if (rank != 2) {
+        op.emitError() << "Expected a 2D memref.";
+        return failure();
+      }
+    } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+      rank = op.getMixedSizes().size();
+      // The payload describes a single 2D surface. With rank < 2 the
+      // `rank - 2` indexing below would read out of bounds; with rank > 2 the
+      // outer dimensions would be silently ignored.
+      if (rank != 2) {
+        op.emitError() << "Expected a 2D shape for a ui64 source.";
+        return failure();
+      }
+    } else {
+      op.emitError() << "Expected source to be a 2D memref or ui64.";
+      return failure();
+    }
+    // Converts a static size/offset into an i32 constant, or a dynamic
+    // index-typed value into i32 (via i64, keeping the low 32 bits).
+    auto toPayloadElem = [&](OpFoldResult ofr) -> Value {
+      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+        Value wide = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+        return arith::TruncIOp::create(rewriter, loc, payloadElemTy, wide);
+      }
+      int32_t cst = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+      return arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, cst);
+    };
+    // Dimensions are ordered [H, W]: the innermost one is the width.
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.size() == 2) {
+      offsetW = toPayloadElem(offsets[rank - 1]);
+      offsetH = toPayloadElem(offsets[rank - 2]);
+    } else {
+      // No 2D offsets on the op: the tile starts at the origin.
+      offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+      offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    }
+    if (sourceIsMemref) {
+      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+      baseShapeW = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+      baseShapeH = arith::ConstantIntOp::create(
+          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+    } else {
+      SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+      baseShapeW = toPayloadElem(sizes[rank - 1]);
+      baseShapeH = toPayloadElem(sizes[rank - 2]);
+      baseAddr = adaptor.getSource();
+    }
+    // Write the i64 base address through the i64 view of the payload, then
+    // fill the remaining i32 slots.
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+    payLoadAsI64 =
+        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+                                 static_cast<int>(NdDescI32Layout::BasePtr));
+    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeW));
+    payload =
+        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+                                 static_cast<int>(NdDescI32Layout::BaseShapeH));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetW, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    payload = vector::InsertOp::create(
+        rewriter, loc, offsetH, payload,
+        static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    rewriter.replaceOp(op, payload);
+    return success();
+  }
+};
+
+/// Lowers xegpu.update_nd_offset by adding the op's [H, W] offsets to the
+/// tensor offsets stored in the descriptor payload.
+class UpdateNdOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    // op.getOffsets() holds only the dynamic operands; static offsets live in
+    // const_offsets. Use the merged per-dimension list so that static offsets
+    // are applied and each offset updates the correct dimension.
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.size() != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+    Value tdesc = adaptor.getTensorDesc();
+    for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+      OpFoldResult ofr = offsets[offsetDim];
+      Value dynOffset = dyn_cast_if_present<Value>(ofr);
+      int64_t staticOffset = 0;
+      if (dynOffset) {
+        // A dynamic offset produced by a constant zero is a no-op.
+        if (auto cst = dyn_cast_if_present<arith::ConstantOp>(
+                dynOffset.getDefiningOp()))
+          if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+              attr && !attr.getInt())
+            continue;
+      } else {
+        staticOffset = cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        // A static zero offset is a no-op.
+        if (staticOffset == 0)
+          continue;
+      }
+      // Dim 0 is the row (H) offset, dim 1 the column (W) offset.
+      const int offsetPos =
+          static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+                                     : NdDescI32Layout::TensorOffsetH);
+      Value oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+      Value offset;
+      if (dynOffset)
+        offset = arith::IndexCastUIOp::create(rewriter, loc,
+                                              rewriter.getI32Type(), dynOffset);
+      else
+        offset = arith::ConstantIntOp::create(
+            rewriter, loc, rewriter.getI32Type(), staticOffset);
+      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      tdesc =
+          vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+    }
+    rewriter.replaceOp(op, tdesc);
+    return success();
+  }
+};
+
+/// Lowers xegpu.load_nd / store_nd / prefetch_nd on a 2D tensor descriptor to
+/// the matching XeVM 2D block op. The base address and shape come from the
+/// descriptor payload. The tile offsets come from the op when it has any,
+/// and otherwise from the payload.
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    Type i32Ty = rewriter.getI32Type();
+
+    auto tdesc = adaptor.getTensorDesc();
+    auto tdescTy = op.getTensorDescType();
+    if (tdescTy.getRank() != 2) {
+      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+    }
+
+    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+    Value payLoadAsI64 =
+        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+    Value basePtr =
+        vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+                                  static_cast<int>(NdDescI32Layout::BasePtr));
+    Value baseShapeW = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+    Value baseShapeH = vector::ExtractOp::create(
+        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+    // Offsets come either from the op itself, as a mix of static
+    // `const_offsets` and dynamic `offsets` operands in [H, W] order, or,
+    // when the op has none, from the tensor descriptor payload.
+    Value offsetW;
+    Value offsetH;
+    auto cOffsets = op.getConstOffsets();
+    auto offsets = op.getOffsets();
+    if ((cOffsets && !cOffsets->empty()) || !offsets.empty()) {
+      // Rebuild one entry per dimension. A static entry equal to
+      // ShapedType::kDynamic is a placeholder for the next dynamic operand.
+      SmallVector<OpFoldResult> mixedOffsets;
+      size_t nextDynamic = 0;
+      if (cOffsets) {
+        for (int64_t cst : *cOffsets) {
+          if (!ShapedType::isDynamic(cst))
+            mixedOffsets.push_back(rewriter.getI64IntegerAttr(cst));
+          else if (nextDynamic < offsets.size())
+            mixedOffsets.push_back(offsets[nextDynamic++]);
+          else
+            return rewriter.notifyMatchFailure(
+                op, "Mismatched static and dynamic offsets.");
+        }
+        if (nextDynamic != offsets.size())
+          return rewriter.notifyMatchFailure(
+              op, "Mismatched static and dynamic offsets.");
+      } else {
+        mixedOffsets.append(offsets.begin(), offsets.end());
+      }
+      if (mixedOffsets.size() != 2)
+        return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
+      for (OpFoldResult ofr : mixedOffsets) {
+        auto v = dyn_cast_if_present<Value>(ofr);
+        if (v && v.getType() != i32Ty &&
+            v.getType() != rewriter.getIndexType())
+          return rewriter.notifyMatchFailure(
+              op, "Expected offsets to be of type i32 or index.");
+      }
+      // Materializes one (validated) offset as an i32 value.
+      auto toI32 = [&](OpFoldResult ofr) -> Value {
+        if (auto v = dyn_cast_if_present<Value>(ofr)) {
+          if (v.getType() == i32Ty)
+            return v;
+          return arith::IndexCastUIOp::create(rewriter, loc, i32Ty, v);
+        }
+        int64_t cst = cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+        return arith::ConstantIntOp::create(rewriter, loc, i32Ty, cst);
+      };
+      // Dim 1 is the column (W) offset and dim 0 the row (H) offset. This
+      // matches how create_nd_tdesc and update_nd_offset fill the payload.
+      offsetW = toI32(mixedOffsets[1]);
+      offsetH = toI32(mixedOffsets[0]);
+    } else {
+      // No offsets on the op: use the ones stored in the tensor descriptor.
+      offsetW = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetW));
+      offsetH = vector::ExtractOp::create(
+          rewriter, loc, tdesc,
+          static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    }
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    auto elemType = tdescTy.getElementType();
+    auto elemBitSize = elemType.getIntOrFloatBitWidth();
+    Value elemByteSize = arith::ConstantIntOp::create(
+        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+    // Surface width in bytes. It is passed twice to the block ops below
+    // (presumably as both width and pitch, i.e. densely packed rows).
+    Value surfaceW =
+        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+    auto tileW = tdescTy.getDimSize(1);
+    auto tileH = tdescTy.getDimSize(0);
+    int32_t vblocks = tdescTy.getArrayLength();
+    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+      VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      VectorType srcFlatVecTy =
+          VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
+      Value srcFlatVec = op.getValue();
+      // The block store takes the data as integers of the element bit width.
+      srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
+                                        rewriter.getIntegerType(elemBitSize));
+      srcFlatVec =
+          vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+      xevm::BlockStore2dOp::create(
+          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+          offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
+      rewriter.eraseOp(op);
+    } else {
+      auto loadCacheControl =
+          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
+        xevm::BlockPrefetch2dOp::create(
+            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
+            offsetH, elemBitSize, tileW, tileH, vblocks,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        rewriter.eraseOp(op);
+      } else {
+        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
+        const bool vnni = op.getPacked().value_or(false);
+        auto transposeValue = op.getTranspose();
+        // Transposed when the permutation's first entry is 1 (i.e. [1, 0]).
+        bool transpose =
+            transposeValue.has_value() && transposeValue.value()[0] == 1;
+        // VNNI-packed loads return i32 elements; otherwise integers of the
+        // element bit width. Both are bitcast back to the result element type.
+        VectorType loadedTy = encodeVectorTypeTo(
+            dstVecTy, vnni ? rewriter.getI32Type()
+                           : rewriter.getIntegerType(elemBitSize));
+
+        Value resultFlatVec = xevm::BlockLoad2dOp::create(
+            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
+            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
+            transpose, vnni,
+            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
+        resultFlatVec = vector::BitCastOp::create(
+            rewriter, loc,
+            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
+            resultFlatVec);
+        rewriter.replaceOp(op, resultFlatVec);
+      }
+    }
+    return success();
+  }
+};
+
+/// Returns the element size, in bytes, of the tensor descriptor that `op`
+/// uses or produces. Sub-byte element types yield 0 (integer division).
+template <
+    typename OpType,
+    typename = std::enable_if_t<llvm::is_one_of<
+        OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
+        xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
+int64_t getElemByteSize(OpType op) {
+  auto elemTy = op.getTensorDesc().getType().getElementType();
+  return elemTy.getIntOrFloatBitWidth() / 8;
+}
+
+/// Builds `baseAddr + offset * elemByteSize` with arith ops and returns the
+/// resulting value. The byte size is materialized as an i64 constant, so
+/// `offset` and `baseAddr` are expected to be i64-typed
+/// (NOTE(review): confirm for all callers).
+static Value addOffset(OpBuilder &builder, Location loc, Value baseAddr,
+                       Value offset, int64_t elemByteSize) {
+  Value byteSize = arith::ConstantIntOp::create(builder, loc,
+                                                builder.getI64Type(),
+                                                elemByteSize);
+  Value byteOffset = arith::MulIOp::create(builder, loc, offset, byteSize);
+  return arith::AddIOp::create(builder, loc, baseAddr, byteOffset);
+}
+
+/// Lowers xegpu.create_tdesc to a per-lane i64 address: the source base
+/// address plus the lane offset scaled by the element byte size.
+class CreateDescToXeVMPattern
+    : public OpConversionPattern<xegpu::CreateDescOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value laneOffsets = adaptor.getOffsets();
+    // The source is a 1D memref or a ui64 address. The memref type is checked
+    // on the original op, since the adaptor holds the converted (LLVM struct)
+    // value instead.
+    Value origSource = op.getSource();
+    Value baseAddr = adaptor.getSource();
+    if (isa<MemRefType>(origSource.getType())) {
+      Value alignedPtr = memref::ExtractAlignedPointerAsIndexOp::create(
+          rewriter, loc, origSource);
+      baseAddr = arith::IndexCastUIOp::create(rewriter, loc,
+                                              rewriter.getI64Type(), alignedPtr);
+    }
+    Value laneAddr = addOffset(rewriter, loc, baseAddr, laneOffsets,
+                               getElemByteSize(op));
+    rewriter.replaceOp(op, laneAddr);
+    return success();
+  }
+};
+
+/// Lowers xegpu.update_offset by advancing the per-lane address by the given
+/// offset scaled by the element byte size.
+class UpdateOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateOffsetOp op,
+                  xegpu::UpdateOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    const int64_t elemByteSize = getElemByteSize(op);
+    Value currentAddr = adaptor.getTensorDesc();
+    Value delta = adaptor.getOffsets();
+    Value advancedAddr =
+        addOffset(rewriter, op.getLoc(), currentAddr, delta, elemByteSize);
+    rewriter.replaceOp(op, advancedAddr);
+    return success();
+  }
+};
+
+/// Lowers xegpu.load (gather) and xegpu.store (scatter) that have a scalar
+/// per-lane offset and mask. The access becomes an LLVM load/store guarded by
+/// an scf.if on the mask. Masked-off loads yield a zero vector.
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    constexpr bool isLoad = std::is_same_v<OpType, xegpu::LoadGatherOp>;
+    Location loc = op.getLoc();
+    MLIRContext *ctxt = rewriter.getContext();
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(op.getTensorDescType().getMemorySpace()));
+
+    // The address operand is `source` for loads and `dest` for stores.
+    Value addrI64;
+    if constexpr (isLoad)
+      addrI64 = adaptor.getSource();
+    else
+      addrI64 = adaptor.getDest();
+    Value mask = adaptor.getMask();
+
+    // Optional element offset; only the scalar (per-lane) form is supported.
+    if (Value laneOffset = adaptor.getOffsets()) {
+      if (isa<VectorType>(laneOffset.getType()))
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      addrI64 =
+          addOffset(rewriter, loc, addrI64, laneOffset, getElemByteSize(op));
+    }
+    Value addrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, addrI64);
+    VectorType valueTy = op.getValueType();
+    VectorType flatValueTy =
+        VectorType::get(valueTy.getNumElements(), valueTy.getElementType());
+    if (isa<VectorType>(mask.getType()))
+      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+
+    if constexpr (isLoad) {
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valueTy}, mask,
+                                         /*addThenBlock=*/true,
+                                         /*addElseBlock=*/true);
+      // Active lane: perform the load with the translated cache control.
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      auto loadOp =
+          LLVM::LoadOp::create(rewriter, loc, flatValueTy, addrLLVM);
+      loadOp->setAttr("cache_control",
+                      xevm::LoadCacheControlAttr::get(
+                          ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(),
+                                                            op.getL3Hint())));
+      Value thenValue = loadOp;
+      if (valueTy != flatValueTy)
+        thenValue =
+            vector::ShapeCastOp::create(rewriter, loc, valueTy, thenValue);
+      scf::YieldOp::create(rewriter, loc, ValueRange{thenValue});
+
+      // Inactive lane: yield a vector of zeros.
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      Type elemTy = valueTy.getElementType();
+      Value elseValue = arith::ConstantOp::create(
+          rewriter, loc,
+          elemTy.isFloat()
+              ? DenseElementsAttr::get(valueTy, FloatAttr::get(elemTy, 0.0))
+              : DenseElementsAttr::get(valueTy, IntegerAttr::get(elemTy, 0)));
+      scf::YieldOp::create(rewriter, loc, ValueRange{elseValue});
+      rewriter.replaceOp(op, ifOp.getResult(0));
+    } else {
+      scf::IfOp ifOp =
+          scf::IfOp::create(rewriter, loc, mask, /*withElseRegion=*/false);
+      rewriter.setInsertionPointToStart(ifOp.getBody());
+      // Active lane: store the (flattened) value with the cache control.
+      Value storeValue = op.getValue();
+      if (storeValue.getType() != flatValueTy)
+        storeValue = vector::ShapeCastOp::create(rewriter, loc, flatValueTy,
+                                                 storeValue);
+      auto storeOp =
+          LLVM::StoreOp::create(rewriter, loc, storeValue, addrLLVM);
+      storeOp->setAttr("cache_control",
+                       xevm::StoreCacheControlAttr::get(
+                           ctxt, translateStoreXeGPUCacheHint(op.getL1Hint(),
+                                                              op.getL3Hint())));
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+
+/// Lowers xegpu.prefetch, with an optional scalar element offset, to
+/// xevm.prefetch on the computed address.
+class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *ctxt = rewriter.getContext();
+    Value addrI64 = adaptor.getSource();
+    // Only the scalar (per-lane) offset form is supported.
+    if (Value laneOffset = adaptor.getOffsets()) {
+      if (isa<VectorType>(laneOffset.getType()))
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      addrI64 =
+          addOffset(rewriter, loc, addrI64, laneOffset, getElemByteSize(op));
+    }
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(op.getTensorDescType().getMemorySpace()));
+    Value addrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, addrI64);
+    auto cacheControl =
+        translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
+    xevm::PrefetchOp::create(
+        rewriter, loc, addrLLVM,
+        xevm::LoadCacheControlAttr::get(ctxt, cacheControl));
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+/// Lowers xegpu.fence to xevm.memfence, translating the XeGPU fence scope to
+/// an XeVM memory scope and the XeGPU memory kind to an XeVM address space.
+class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    xevm::MemScope memScope = translateFenceScope(op.getFenceScope());
+    xevm::AddrSpace addrSpace = translateMemoryKind(op.getMemoryKind());
+    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  // Both helpers use switches without a `default:` and place llvm_unreachable
+  // AFTER the switch. This keeps -Wswitch coverage warnings if an enumerator
+  // is added. It also diagnoses an out-of-range value instead of silently
+  // falling back to a hard-coded default. (Previously the unreachable calls
+  // sat after a `break` and could never execute.)
+  static xevm::MemScope translateFenceScope(xegpu::FenceScope scope) {
+    switch (scope) {
+    case xegpu::FenceScope::Workgroup:
+      return xevm::MemScope::WORKGROUP;
+    case xegpu::FenceScope::GPU:
+      return xevm::MemScope::DEVICE;
+    }
+    llvm_unreachable("Unknown XeGPU fence scope.");
+  }
+
+  static xevm::AddrSpace translateMemoryKind(xegpu::MemorySpace kind) {
+    switch (kind) {
+    case xegpu::MemorySpace::Global:
+      return xevm::AddrSpace::GLOBAL;
+    case xegpu::MemorySpace::SLM:
+      return xevm::AddrSpace::SHARED;
+    }
+    llvm_unreachable("Unknown XeGPU fence memory kind.");
+  }
+};
+
+/// Lowers xegpu.dpas to a single xevm.mma.
+/// - A missing accumulator is replaced by a zero vector of the result type.
+/// - The C/D operands are flattened to 1-D around the MMA with
+///   vector.shape_cast.
+/// - Element types without an XeVM encoding make the pattern fail to match,
+///   rather than aborting.
+class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto aTy = cast<VectorType>(op.getLhs().getType());
+    auto bTy = cast<VectorType>(op.getRhs().getType());
+    auto resultType = cast<VectorType>(op.getResultType());
+
+    // Maps an MLIR element type to its XeVM MMA encoding, or std::nullopt if
+    // there is none. Unsupported types are reachable from user IR, so they
+    // are reported as a match failure instead of hitting llvm_unreachable.
+    auto encodePrecision = [&](Type type) -> std::optional<xevm::ElemType> {
+      if (type == rewriter.getBF16Type())
+        return xevm::ElemType::BF16;
+      if (type == rewriter.getF16Type())
+        return xevm::ElemType::F16;
+      if (type == rewriter.getTF32Type())
+        return xevm::ElemType::TF32;
+      if (type.isInteger(8))
+        return type.isUnsignedInteger() ? xevm::ElemType::U8
+                                        : xevm::ElemType::S8;
+      if (type == rewriter.getF32Type())
+        return xevm::ElemType::F32;
+      if (type.isInteger(32))
+        return xevm::ElemType::S32;
+      return std::nullopt;
+    };
+
+    // Validate every operand precision before creating any IR. When the
+    // accumulator is absent, a zero vector of the result type is used below,
+    // so the C precision is then the result element type.
+    Value acc = op.getAcc();
+    VectorType accTy = acc ? cast<VectorType>(acc.getType()) : resultType;
+    auto precATy = encodePrecision(aTy.getElementType());
+    auto precBTy = encodePrecision(bTy.getElementType());
+    auto precCTy = encodePrecision(accTy.getElementType());
+    auto precDTy = encodePrecision(resultType.getElementType());
+    if (!precATy || !precBTy || !precCTy || !precDTy)
+      return rewriter.notifyMatchFailure(op, "unsupported DPAS element type");
+    std::optional<unsigned> opsPerDword = getNumOperandsPerDword(*precATy);
+    if (!opsPerDword)
+      return rewriter.notifyMatchFailure(op,
+                                         "unsupported DPAS A element type");
+
+    Value c = acc;
+    if (!c) {
+      auto elementTy = resultType.getElementType();
+      Attribute initValueAttr;
+      if (isa<FloatType>(elementTy))
+        initValueAttr = FloatAttr::get(elementTy, 0.0);
+      else
+        initValueAttr = IntegerAttr::get(elementTy, 0);
+      c = arith::ConstantOp::create(
+          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
+    }
+
+    Value aVec = op.getLhs();
+    Value bVec = op.getRhs();
+    auto cvecty = cast<VectorType>(c.getType());
+    // The MMA op is built on a flattened (1-D) C/D type; reshape C if needed.
+    VectorType cNty =
+        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
+    if (cvecty != cNty)
+      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
+    // Below are uArch dependent values, should move away from hardcoding.
+    constexpr int32_t systolicDepth{8};
+    constexpr int32_t executionSize{16};
+    Value dpasRes = xevm::MMAOp::create(
+        rewriter, loc, cNty, aVec, bVec, c,
+        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
+                                systolicDepth * *opsPerDword),
+        xevm::MMATypesAttr::get(ctxt, *precDTy, *precATy, *precBTy, *precCTy));
+    if (cvecty != cNty)
+      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
+    rewriter.replaceOp(op, dpasRes);
+    return success();
+  }
+
+private:
+  /// Number of A-operand elements per 32-bit dword (used to scale the
+  /// systolic depth into the MMA K dimension), or std::nullopt for element
+  /// types this lowering does not handle as an A operand.
+  static std::optional<unsigned> getNumOperandsPerDword(xevm::ElemType pTy) {
+    switch (pTy) {
+    case xevm::ElemType::TF32:
+      return 1;
+    case xevm::ElemType::BF16:
+    case xevm::ElemType::F16:
+      return 2;
+    case xevm::ElemType::U8:
+    case xevm::ElemType::S8:
+      return 4;
+    default:
+      return std::nullopt;
+    }
+  }
+};
+
+/// Maps an arith atomic RMW kind to the LLVM atomicrmw binary operation used
+/// to lower it. Returns std::nullopt for every kind not listed here; callers
+/// must treat that as "not supported by this lowering".
+///
+/// NOTE(review): LLVM `atomicrmw fmax`/`fmin` follow maxnum/minnum NaN
+/// semantics, while arith `maximumf`/`minimumf` propagate NaN. Confirm this
+/// mapping is the intended approximation.
+static std::optional<LLVM::AtomicBinOp>
+matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
+  switch (arithKind) {
+  case arith::AtomicRMWKind::addf:
+    return LLVM::AtomicBinOp::fadd;
+  case arith::AtomicRMWKind::addi:
+    return LLVM::AtomicBinOp::add;
+  case arith::AtomicRMWKind::assign:
+    return LLVM::AtomicBinOp::xchg;
+  case arith::AtomicRMWKind::maximumf:
+    return LLVM::AtomicBinOp::fmax;
+  case arith::AtomicRMWKind::maxs:
+    return LLVM::AtomicBinOp::max;
+  case arith::AtomicRMWKind::maxu:
+    return LLVM::AtomicBinOp::umax;
+  case arith::AtomicRMWKind::minimumf:
+    return LLVM::AtomicBinOp::fmin;
+  case arith::AtomicRMWKind::mins:
+    return LLVM::AtomicBinOp::min;
+  case arith::AtomicRMWKind::minu:
+    return LLVM::AtomicBinOp::umin;
+  case arith::AtomicRMWKind::ori:
+    return LLVM::AtomicBinOp::_or;
+  case arith::AtomicRMWKind::andi:
+    return LLVM::AtomicBinOp::_and;
+  default:
+    // Every path in this switch returns, so no trailing llvm_unreachable is
+    // needed (the previous one was dead code).
+    return std::nullopt;
+  }
+}
+
+/// Lowers xegpu.atomic_rmw to one scalar llvm.atomicrmw (seq_cst) per element
+/// of the flattened value vector. Element i is applied at GEP index i from
+/// the descriptor's base address, and the old values returned by the atomics
+/// are collected into the result vector.
+class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Reject kinds without an LLVM atomicrmw mapping before emitting any IR,
+    // rather than asserting (which aborts, or is UB in release builds).
+    std::optional<LLVM::AtomicBinOp> atomicKind =
+        matchSimpleAtomicOp(op.getKind());
+    if (!atomicKind)
+      return rewriter.notifyMatchFailure(op, "unsupported atomic RMW kind");
+
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdesc = op.getTensorDesc().getType();
+    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
+    // The pass's type converter maps scattered tensor descriptors to i64, so
+    // the adapted operand is normally already the i64 address.
+    // arith.index_cast requires an index on one side, so only apply it when
+    // the operand really is an index (an i64 -> i64 index_cast is invalid IR).
+    Value basePtrI64 = adaptor.getTensorDesc();
+    if (isa<IndexType>(basePtrI64.getType()))
+      basePtrI64 = arith::IndexCastOp::create(rewriter, loc,
+                                              rewriter.getI64Type(), basePtrI64);
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
+    VectorType srcOrDstFlatVecTy = VectorType::get(
+        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+    Value srcFlatVec = vector::ShapeCastOp::create(
+        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
+    // NOTE(review): the op's mask operand is not consulted; every element is
+    // updated unconditionally. Confirm this is intended.
+    Value resVec = srcFlatVec;
+    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
+      // Element i has not been overwritten yet, so this still reads the
+      // original input value.
+      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
+      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
+                                           rewriter.getIndexAttr(i));
+      Value currPtr =
+          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
+                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
+      Value newVal =
+          LLVM::AtomicRMWOp::create(rewriter, loc, *atomicKind, currPtr, val,
+                                    LLVM::AtomicOrdering::seq_cst);
+      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
+    }
+    rewriter.replaceOp(op, resVec);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Pass Definition
+//===----------------------------------------------------------------------===//
+
+struct ConvertXeGPUToXeVMPass
+    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
+  using Base::Base;
+
+  /// Partially converts the current operation.
+  /// - XeGPU ops are rewritten into XeVM / LLVM / vector / arith ops, and the
+  ///   XeGPU dialect is illegal.
+  /// - The listed dialects (plus SCF, whose structural type conversion is
+  ///   registered) stay legal.
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    LLVMTypeConverter typeConverter(ctx);
+
+    // Vectors:
+    // - index elements become i64;
+    // - rank-0 and single-element vectors decay to their element type;
+    // - every other vector is flattened to 1-D.
+    typeConverter.addConversion([ctx](VectorType vecTy) -> Type {
+      Type elemTy = vecTy.getElementType();
+      if (elemTy.isIndex())
+        elemTy = IntegerType::get(ctx, 64);
+      if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1)
+        return elemTy;
+      return VectorType::get(vecTy.getNumElements(), elemTy);
+    });
+    // Scattered descriptors become an i64 address; all other descriptors
+    // become a vector<8xi32> payload.
+    typeConverter.addConversion([ctx](xegpu::TensorDescType descTy) -> Type {
+      if (descTy.isScattered())
+        return IntegerType::get(ctx, 64);
+      return VectorType::get(8, IntegerType::get(ctx, 32));
+    });
+
+    // ui64 -> index (index.castu) -> requested type (arith.index_cast).
+    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      Value input = inputs.front();
+      if (input.getType() != builder.getIntegerType(64, false))
+        return {};
+      Value asIndex =
+          index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+              .getResult();
+      return arith::IndexCastOp::create(builder, loc, type, asIndex)
+          .getResult();
+    };
+
+    // Single-element vector -> its only element (vector.extract), with an
+    // extra arith.index_cast when that element is an index.
+    auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
+                                          ValueRange inputs,
+                                          Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      Value input = inputs.front();
+      auto vecTy = dyn_cast<VectorType>(input.getType());
+      if (!vecTy || vecTy.getNumElements() != 1)
+        return {};
+      Value elem =
+          vector::ExtractOp::create(builder, loc, input, 0).getResult();
+      if (vecTy.getElementType().isIndex())
+        elem =
+            arith::IndexCastOp::create(builder, loc, type, elem).getResult();
+      return elem;
+    };
+
+    // Registration order is significant for materialization lookup; keep it.
+    typeConverter.addSourceMaterialization(ui64MaterializationCast);
+    typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+    typeConverter.addTargetMaterialization(ui64MaterializationCast);
+    typeConverter.addTargetMaterialization(vector1DMaterializationCast);
+
+    ConversionTarget target(*ctx);
+    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
+                           vector::VectorDialect, arith::ArithDialect,
+                           memref::MemRefDialect, gpu::GPUDialect,
+                           index::IndexDialect>();
+    target.addIllegalDialect<xegpu::XeGPUDialect>();
+
+    RewritePatternSet patterns(ctx);
+    populateXeGPUToXeVMConversionPatterns(patterns, typeConverter);
+    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
+                                                         patterns, target);
+    LogicalResult converted =
+        applyPartialConversion(getOperation(), target, std::move(patterns));
+    if (failed(converted))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// Pattern Population
+//===----------------------------------------------------------------------===//
+/// Adds every XeGPU -> XeVM conversion pattern defined in this file to
+/// \p patterns. All of them share \p typeConverter.
+void mlir::populateXeGPUToXeVMConversionPatterns(
+    RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<
+      // Block (nd) descriptor creation, update, load/store/prefetch.
+      CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
+      LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
+      LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
+      LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>,
+      // Scattered descriptors, atomics, prefetch and gather/scatter.
+      CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
+      AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
+      LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
+      LoadStoreToXeVMPattern<xegpu::StoreScatterOp>,
+      // Synchronization and matrix multiply.
+      FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, context);
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
new file mode 100644
index 0000000000000..4fba920f023c4
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+// Lowering of xegpu.create_nd_tdesc into a vector<8xi32> descriptor payload.
+// Payload layout expected by the lines below:
+//   - i64 lane [0] (i32 lanes 0-1): base address;
+//   - i32 lane [2]: innermost dimension size;
+//   - i32 lane [3]: outer dimension size;
+//   - i32 lanes [4]/[5]: zero-filled.
+gpu.module @create_nd_tdesc {
+  // CHECK-LABEL: gpu.func @create_nd_tdesc
+  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64
+  // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
+  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+       %stride1: index, %stride2: index) kernel {
+        // Case 1: raw ui64 pointer with dynamic shape. The pointer reaches i64
+        // via index.castu + arith.index_cast; shape2 fills lane [2], shape1 lane [3].
+         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+         // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+        // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
+        // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
+        // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32
+        // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64
+        // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32
+        // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR9:.*]] = vector.insert %[[VAR3]], %[[VAR8]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR10:.*]] = vector.insert %[[VAR5]], %[[VAR9]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR11:.*]] = vector.insert %[[C0_I32]], %[[VAR10]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VAR12:.*]] = vector.insert %[[C0_I32_0]], %[[VAR11]] [5] : i32 into vector<8xi32>
+        %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
+            : ui64 -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
+        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+
+        // Case 2: static memref source. The address comes from
+        // memref.extract_aligned_pointer_as_index, and the sizes are the
+        // constants 16 (lane [2]) and 8 (lane [3]).
+        // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
+        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+        // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[VAR16:.*]] = vector.bitcast %[[VAR15]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[VAR17:.*]] = vector.insert %[[C16_I32]], %[[VAR16]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[VAR18:.*]] = vector.insert %[[C8_I32]], %[[VAR17]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[VAR19:.*]] = vector.insert %[[C0_I32_2]], %[[VAR18]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[VAR20:.*]] = vector.insert %[[C0_I32_3]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        gpu.return
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
new file mode 100644
index 0000000000000..15940fc4aca26
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -0,0 +1,17 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
+#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
+
+// xegpu.dpas on 1-D f16/f16/f32 operands lowers to a single xevm.mma.
+// The expected shape follows the lowering:
+//   - m = number of C elements (8);
+//   - n = execution size (16);
+//   - k = systolic depth 8 x 2 f16 operands per dword = 16.
+// (Module renamed from the copy-pasted @load_store_check.)
+gpu.module @dpas_check {
+    //CHECK: func.func @dpas(%[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>) -> vector<8xf32>
+    func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
+        // Loads are checked in a separate test.
+        // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
+        // CHECK-SAME:    : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
+        %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+            : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
+        return %d : vector<8xf32>
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/fence.mlir b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
new file mode 100644
index 0000000000000..cedfcace398a6
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/fence.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+// xegpu.fence (memory_kind = global, fence_scope = workgroup) lowers to
+// xevm.memfence with addrspace global and scope workgroup. The thread-id
+// arithmetic and memref.store around it belong to dialects that the pass
+// keeps legal, so they are not rewritten.
+gpu.module @fence_check {
+    gpu.func @fence(%dst: memref<8x16xf32, 1>) kernel {
+        %tid_x = gpu.thread_id x
+        %tid_x_i32 = arith.index_cast %tid_x : index to i32
+        %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+
+        // CHECK: xevm.memfence <{addrspace = #xevm.addr_space<global>, scope = #xevm.mem_scope<workgroup>}>
+        xegpu.fence memory_kind = global, fence_scope = workgroup
+        %c0 = arith.constant 0 : index
+        memref.store %tid_x_f32, %dst[%c0, %c0] : memref<8x16xf32, 1>
+        gpu.return
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
new file mode 100644
index 0000000000000..c692da632d458
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+// Round trip through two nd descriptors:
+//   1. load_nd an 8x16xf32 tile;
+//   2. overwrite element [0] with the thread id;
+//   3. store_nd the result back.
+// Both accesses become xevm 2-D block operations on i32 data.
+gpu.module @load_store_check {
+    gpu.func @load_store(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        %dstte = memref.memory_space_cast %dst : memref<8x16xf32, 1> to memref<8x16xf32>
+
+        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+
+        // load_nd lowering:
+        //   - the address comes from i64 lane [0];
+        //   - width and height come from i32 lanes [2]/[3];
+        //   - the width is scaled by 4 bytes (f32) and used as both row size
+        //     and pitch;
+        //   - xevm.blockload2d returns vector<8xi32>, which is bitcast to f32.
+        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+        //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+        //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+        //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+        //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32
+        //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
+        //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
+        //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
+        //CHECK: %[[LD_LOADED_I32:.*]] = xevm.blockload2d %[[LD_LLVMPTR]], %[[LD_BASE_ROW_IN_BYTES]],
+        //CHECK-SAME: %[[LD_BASE_H]], %[[LD_BASE_ROW_IN_BYTES]], %[[LD_TILE_W]], %[[LD_TILE_H]]
+        //CHECK-SAME: <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   pack_register = false, tile_height = 8 : i32, tile_width = 16 : i32, transpose = false,
+        //CHECK-SAME:   v_blocks = 1 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
+        %loaded = xegpu.load_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32> -> vector<8xf32>
+        //CHECK: %[[LD_LOADED_F32:.*]] = vector.bitcast %[[LD_LOADED_I32]] : vector<8xi32> to vector<8xf32>
+
+        %tid_x = gpu.thread_id x
+        %tid_x_i32 = arith.index_cast %tid_x : index to i32
+        %tid_x_f32 = arith.sitofp %tid_x_i32 : i32 to f32
+        //CHECK: %[[LOADED_F32_MODIFIED:.*]] = vector.insert %{{.*}}, %[[LD_LOADED_F32]] [0] : f32 into vector<8xf32>
+        %loaded_modified = vector.insert %tid_x_f32, %loaded[0] : f32 into vector<8xf32>
+
+        // CHECK: %[[PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[DESC_0:.*]] = vector.insert %[[PTR_AS_I64]], %[[CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[DESC_1:.*]] = vector.bitcast %[[DESC_0]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[DESC_2:.*]] = vector.insert {{.*}}, %[[DESC_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[DESC_3:.*]] = vector.insert {{.*}}, %[[DESC_2]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[DESC_4:.*]] = vector.insert {{.*}}, %[[DESC_3]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[DESC:.*]] = vector.insert {{.*}}, %[[DESC_4]] [5] : i32 into vector<8xi32>
+        %dst_tdesc = xegpu.create_nd_tdesc %dstte : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+
+        // store_nd lowering reads the payload the same way the load does.
+        // The f32 value is bitcast to vector<8xi32> for xevm.blockstore2d.
+        //CHECK: %[[DESC_I64:.*]] = vector.bitcast %[[DESC]] : vector<8xi32> to vector<4xi64>
+        //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
+        //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
+        //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
+        //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32
+        //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
+        //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
+        //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
+        //CHECK: %[[FLAT_VALUE_I32:.*]] = vector.bitcast %[[LOADED_F32_MODIFIED]] : vector<8xf32> to vector<8xi32>
+        //CHECK: xevm.blockstore2d %[[LLVMPTR]], %[[BASE_ROW_IN_BYTES]], %[[BASE_H]], %[[BASE_ROW_IN_BYTES]],
+        //CHECK-SAME: %[[TILE_W]], %[[TILE_H]], %[[FLAT_VALUE_I32]]
+        //CHECK-SAME: <{cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:   tile_height = 8 : i32, tile_width = 16 : i32}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>)
+        xegpu.store_nd %loaded_modified, %dst_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : vector<8xf32>, !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
+        gpu.return
+    }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
new file mode 100644
index 0000000000000..f6d023307313a
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -0,0 +1,357 @@
+// RUN: mlir-opt %s --split-input-file -convert-xegpu-to-xevm | FileCheck %s
+
+gpu.module @test {  // Masked gather load (chunk_size = 2) from a raw ui64 base address with constant zero offsets.
+// CHECK-LABEL: @load_gather_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
+  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>  // All-true lane mask; its extracted i1 guards the scf.if around the llvm.load.
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+  %2 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>  // Address = base + offset * 4 (sizeof f32).
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+  // CHECK: %[[VAR8:.*]] = scf.if %[[VAR4]] -> (vector<2xf32>) {
+  // CHECK:      %[[VAR9:.*]] = llvm.load %[[VAR7]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+  // CHECK-SAME:     : !llvm.ptr<1> -> vector<2xf32>
+  // CHECK:      scf.yield %[[VAR9]] : vector<2xf32>
+  // CHECK:    } else {
+  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<2xf32>
+  // CHECK:      scf.yield %[[CST_1]] : vector<2xf32>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1> -> vector<2xf32>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Masked gather load of one f32 per lane (no chunk) from a memref base with constant offsets.
+// CHECK-LABEL: @load_gather_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>  // All-true lane mask; its extracted i1 guards the scf.if.
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+  %2 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>  // Address = aligned pointer + offset * 4 (sizeof f32).
+      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+  // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (f32) {
+  // CHECK:      %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+  // CHECK-SAME:     : !llvm.ptr<1> -> vector<1xf32>
+  // CHECK:      %[[VAR9:.*]] = vector.extract %[[VAR8]][0] : f32 from vector<1xf32>
+  // CHECK:      scf.yield %[[VAR9]] : f32
+  // CHECK:    } else {
+  // CHECK:      %[[CST_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+  // CHECK:      %[[VAR10:.*]] = vector.extract %[[CST_1]][0] : f32 from vector<1xf32>
+  // CHECK:      scf.yield %[[VAR10]] : f32
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1> -> vector<1xf32>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Masked gather load (chunk_size = 8, vector<8xf16> per lane) with a runtime per-lane offset argument.
+// CHECK-LABEL: @load_gather_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+  %2 = xegpu.create_tdesc %src, %offset : memref<256xf16>, vector<1xindex>  // Address = aligned pointer + offset * 2 (sizeof f16).
+      -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+  // CHECK: %[[VAR7:.*]] = scf.if %[[VAR2]] -> (vector<8xf16>) {
+  // CHECK:      %[[VAR8:.*]] = llvm.load %[[VAR6]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+  // CHECK-SAME:     : !llvm.ptr<1> -> vector<8xf16>
+  // CHECK:      scf.yield %[[VAR8]] : vector<8xf16>
+  // CHECK:    } else {
+  // CHECK:      %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
+  // CHECK:      scf.yield %[[CST_0]] : vector<8xf16>
+  %3 = xegpu.load %2, %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xi1> -> vector<8xf16>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Gather load whose offsets come from both create_tdesc (%offset1) and xegpu.load (%offset2).
+// CHECK-LABEL: @load_gather_memref_src_load_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+  // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+  // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64
+  // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
+  %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex>  // Base + %offset1 * 2; the load-time %offset2 * 2 is added on top.
+      -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64
+  // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64
+  // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
+  // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
+  // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) {
+  // CHECK:      %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
+  // CHECK-SAME:     : !llvm.ptr<1> -> vector<8xf16>
+  // CHECK:      scf.yield %[[VAR12]] : vector<8xf16>
+  // CHECK:    } else {
+  // CHECK:      %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
+  // CHECK:      scf.yield %[[CST_0]] : vector<8xf16>
+  %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Masked scatter store of vector<2xf32> (chunk_size = 2) to a raw ui64 base address with constant offsets.
+// CHECK-LABEL: @store_scatter_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
+  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900000e+00> : vector<2xf32>
+  %2 = arith.constant dense<2.9>: vector<2xf32>
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR1]], %[[VAR5]] : i64
+  %3 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>  // Address = base + offset * 4 (sizeof f32); the store is guarded by scf.if on the mask.
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR7:.*]] = llvm.inttoptr %[[VAR6]] : i64 to !llvm.ptr<1>
+  // CHECK:    scf.if %[[VAR4]] {
+  // CHECK:      llvm.store %[[CST_1]], %[[VAR7]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+  // CHECK-SAME:     : vector<2xf32>, !llvm.ptr<1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<2xf32>, !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Masked scatter store of vector<2xf16> (chunk_size = 2) through a memref<256xf32> base with constant offsets.
+// CHECK-LABEL: @store_scatter_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
+  %2 = arith.constant dense<2.9>: vector<2xf16>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
+  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+  %3 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>  // NOTE(review): f32 memref but f16 descriptor; offsets scale by 2 (f16 size). Confirm the mismatch is intended.
+      -> !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+  // CHECK: scf.if %[[VAR2]] {
+  // CHECK:      llvm.store %[[CST_1]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+  // CHECK-SAME:     : vector<2xf16>, !llvm.ptr<1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<2xf16>, !xegpu.tensor_desc<1x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xi1>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Masked scatter store of one f32 per lane (no chunk) with a runtime per-lane offset argument.
+// CHECK-LABEL: @store_scatter_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+  %2 = arith.constant dense<2.9>: vector<1xf32>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
+  %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>  // Address = aligned pointer + offset * 4 (sizeof f32).
+      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+  // CHECK: scf.if %[[VAR2]] {
+  // CHECK:      llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+  // CHECK-SAME:     : vector<1xf32>, !llvm.ptr<1>
+  xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Scatter store whose offsets come from both create_tdesc (%offset) and xegpu.store (%offset2).
+// CHECK-LABEL: @store_scatter_memref_src_store_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
+  %1 = arith.constant dense<1>: vector<1xi1>
+  // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+  %2 = arith.constant dense<2.9>: vector<1xf32>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
+  %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>  // Base + %offset * 4; the store-time %offset2 * 4 is added on top.
+      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+  // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64
+  // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
+  // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
+  // CHECK: scf.if %[[VAR4]] {
+  // CHECK:      llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+  // CHECK-SAME:     : vector<1xf32>, !llvm.ptr<1>
+  xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Scatter prefetch (chunk_size = 2) from a raw ui64 base address; lowers to xevm.prefetch on the computed address.
+// CHECK-LABEL: @prefetch_ui64_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: ui64
+gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
+  // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR1]], %[[VAR4]] : i64
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<1xindex>  // Address = base + offset * 4 (sizeof f32); the prefetch takes no lane mask.
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
+  // CHECK: xevm.prefetch %[[VAR6]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Scatter prefetch (chunk_size = 2) from a memref base with constant offsets.
+// CHECK-LABEL: @prefetch_memref_src_constant_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
+gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  %0 = arith.constant dense<0> : vector<1xindex>
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+  %1 = xegpu.create_tdesc %src, %0 : memref<256xf32>, vector<1xindex>  // Address = aligned pointer + offset * 4 (sizeof f32).
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+  // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Scatter prefetch (chunk_size = 2) with a runtime per-lane offset argument.
+// CHECK-LABEL: @prefetch_memref_src_value_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
+gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+  %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>  // Address = aligned pointer + offset * 4 (sizeof f32).
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[VAR5:.*]] = llvm.inttoptr %[[VAR4]] : i64 to !llvm.ptr<1>
+  // CHECK: xevm.prefetch %[[VAR5]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  gpu.return
+}
+}
+// -----
+
+gpu.module @test {  // Scatter prefetch whose offsets come from both create_tdesc (%offset) and xegpu.prefetch (%offset2).
+// CHECK-LABEL: @prefetch_memref_src_prefetch_offset
+// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
+gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
+  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
+  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
+  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
+  %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>  // Base + %offset * 4; the prefetch-time %offset2 * 4 is added on top.
+      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
+  // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
+  // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64
+  // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1>
+  // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
+  xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xindex>
+  gpu.return
+}
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
new file mode 100644
index 0000000000000..8513b4f9857fb
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm -split-input-file %s | FileCheck %s
+
+gpu.module @prefetch_nd_check {  // 2D block prefetch of an 8x16 f32 tile through a global-memory nd tensor descriptor.
+    gpu.func @prefetch_nd(%src: memref<8x16xf32, 1>, %dst: memref<8x16xf32, 1>) kernel {
+        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // Only %src is used: its tile at [0, 0] is prefetched. %dst is not referenced by this test.
+
+        // CHECK: %[[LD_PTR_AS_I64:.*]] = arith.index_castui {{.*}} : index to i64
+        // CHECK: %[[LD_CREATE_DESC_I64:.*]] = vector.bitcast {{.*}} : vector<8xi32> to vector<4xi64>
+        // CHECK: %[[LD_DESC_0:.*]] = vector.insert %[[LD_PTR_AS_I64]], %[[LD_CREATE_DESC_I64]] [0] : i64 into vector<4xi64>
+        // CHECK: %[[LD_DESC_1:.*]] = vector.bitcast %[[LD_DESC_0]] : vector<4xi64> to vector<8xi32>
+        // CHECK: %[[LD_DESC_2:.*]] = vector.insert {{.*}}, %[[LD_DESC_1]] [2] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_3:.*]] = vector.insert {{.*}}, %[[LD_DESC_2]] [3] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC_4:.*]] = vector.insert {{.*}}, %[[LD_DESC_3]] [4] : i32 into vector<8xi32>
+        // CHECK: %[[LD_DESC:.*]] = vector.insert {{.*}}, %[[LD_DESC_4]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32,
+            #xegpu.block_tdesc_attr<memory_space = global>, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+        //CHECK: %[[LD_DESC_I64:.*]] = vector.bitcast %[[LD_DESC]] : vector<8xi32> to vector<4xi64>
+        //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
+        //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
+        //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
+        //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32
+        //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
+        //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
+        //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32
+        //CHECK: xevm.blockprefetch2d %[[PREF_LLVMPTR]], %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_BASE_H]],
+        //CHECK-SAME:   %[[PREF_BASE_ROW_IN_BYTES]], %[[PREF_TILE_W]], %[[PREF_TILE_H]]
+        //CHECK-SAME:   <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>, elem_size_in_bits = 32 : i32,
+        //CHECK-SAME:     tile_height = 8 : i32, tile_width = 16 : i32, v_blocks = 1 : i32}>
+        //CHECK-SAME:   : (!llvm.ptr<1>, i32, i32, i32, i32, i32)
+        xegpu.prefetch_nd %src_tdesc[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+            : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>,
+                  #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
+
+        gpu.return
+    }
+}
+
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
new file mode 100644
index 0000000000000..e9d7fd4cf40a6
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @update_offset {  // update_offset on a scatter descriptor lowers to an extra "+ offset * 4 (sizeof f32)" on the address.
+  // CHECK-LABEL: gpu.func @update_offset
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+  gpu.func @update_offset(%src: memref<128xf32>) kernel {
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+    %offset = arith.constant dense<0> : vector<1xindex>  // NOTE(review): reused by create_tdesc and update_offset, so both scaled terms derive from one value.
+    // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
+    // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
+    // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
+    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
+    // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
+    // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
+    %new_tdesc = xegpu.update_offset %src_tdesc, %offset : !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>,
+        vector<1xindex>
+    gpu.return
+  }
+}

>From 97d5eceb4c15fbc508cc081d88cce3f540271caf Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 15:25:41 +0000
Subject: [PATCH 02/10] Apply clang format.

---
 .../lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 380409afbc62e..32983152ef5bd 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -558,10 +558,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
       Value loaded =
           LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
-      loaded.getDefiningOp()->setAttr("cache_control",
-                                      xevm::LoadCacheControlAttr::get(
-                                          ctxt, translateLoadXeGPUCacheHint(
-                                                     op.getL1Hint(), op.getL3Hint())));
+      loaded.getDefiningOp()->setAttr(
+          "cache_control", xevm::LoadCacheControlAttr::get(
+                               ctxt, translateLoadXeGPUCacheHint(
+                                         op.getL1Hint(), op.getL3Hint())));
       if (srcOrDstVecTy != srcOrDstFlatVecTy) {
         loaded =
             vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
@@ -588,12 +588,12 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
         srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
                                                  srcOrDstFlatVecTy, srcFlatVec);
       }
-      auto storeOp = LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+      auto storeOp =
+          LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
       storeOp.getOperation()->setAttr(
-          "cache_control",
-          xevm::StoreCacheControlAttr::get(ctxt,
-                                          translateStoreXeGPUCacheHint(
-                                              op.getL1Hint(), op.getL3Hint())));
+          "cache_control", xevm::StoreCacheControlAttr::get(
+                               ctxt, translateStoreXeGPUCacheHint(
+                                         op.getL1Hint(), op.getL3Hint())));
       rewriter.eraseOp(op);
     }
     return success();

>From 148df78d506d26a26e5974848dce42df1c12ba8b Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 15:29:01 +0000
Subject: [PATCH 03/10] Remove commented out code.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 32983152ef5bd..89f40c22e7a68 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -380,8 +380,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
-    // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(),
-    // elemBitSize);
     Value elemByteSize = arith::ConstantIntOp::create(
         rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
     Value surfaceW =
@@ -695,8 +693,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     };
     xevm::ElemType precATy = encodePrecision(aTy.getElementType());
     xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
-    // auto precA = xevm::ElemTypeAttr::get(ctxt, precATy);
-    // auto precB = xevm::ElemTypeAttr::get(ctxt, precBTy);
     Value c = op.getAcc();
     if (!c) {
       auto elementTy = resultType.getElementType();
@@ -714,8 +710,6 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
     auto cvecty = cast<VectorType>(c.getType());
     xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
     xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
-    // auto precC = xevm::ElemTypeAttr::get(ctxt, precCTy);
-    // auto precD = xevm::ElemTypeAttr::get(ctxt, precDTy);
     VectorType cNty =
         VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
     if (cvecty != cNty)

>From edd191c9939255c1639e6fbeea0ed7e6786634a8 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Wed, 20 Aug 2025 17:59:49 +0000
Subject: [PATCH 04/10] Remove dead code.

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 -------------
 1 file changed, 13 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 89f40c22e7a68..776380974c549 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -56,19 +56,6 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   llvm_unreachable("Unknown XeGPU memory space.");
 }
 
-template <typename T>
-std::tuple<bool, int32_t, int32_t> checkAllLinear(SmallVector<T> denseAttr) {
-  assert(!denseAttr.empty());
-  const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
-  if (denseAttr.size() < 2)
-    return {true, 0, intercept};
-  const T slope{denseAttr[1] - denseAttr[0]};
-  for (size_t i = 1; i < denseAttr.size(); ++i)
-    if (denseAttr[i] - denseAttr[i - 1] != slope)
-      return {false, 0, 0};
-  return {true, static_cast<int32_t>(slope), intercept};
-}
-
 VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
   auto elemType = currentVecType.getElementType();
   auto currentBitWidth = elemType.getIntOrFloatBitWidth();

>From 9710c9eca70796ea666f8656ab2ba2e175cc54f4 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Thu, 21 Aug 2025 22:41:37 +0000
Subject: [PATCH 05/10] Temp save.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 128 +++++++++++++-----
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |   4 +-
 .../XeGPUToXeVM/loadstoreprefetch.mlir        |   4 +-
 .../XeGPUToXeVM/materializecast.mlir          |  49 +++++++
 .../Conversion/XeGPUToXeVM/update_offset.mlir |   6 +-
 5 files changed, 150 insertions(+), 41 deletions(-)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 776380974c549..4ff5321c1d9d2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -426,18 +427,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
-template <
-    typename OpType,
-    typename = std::enable_if_t<llvm::is_one_of<
-        OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
-        xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
-int64_t getElemByteSize(OpType op) {
-  // Get the element byte size from the tensor descriptor.
-  auto elemBitWidth =
-      op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
-  return elemBitWidth / 8;
-}
-
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
 auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
@@ -456,23 +445,23 @@ class CreateDescToXeVMPattern
   LogicalResult
   matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected element type bit width to be multiple of 8.");
+    }
     auto loc = op.getLoc();
+    // offsets are provided as scalar i64 by type converter.
     auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or ui64
-    // Using "op" instead of "adaptor" since we want to access memref type
-    // instead of LLVM struct type.
-    auto memrefTy = dyn_cast<MemRefType>(op.getSource().getType());
-    Value subGroupAddr;
-    if (memrefTy) {
-      subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create(
-          rewriter, loc, op.getSource());
-      subGroupAddr = arith::IndexCastUIOp::create(
-          rewriter, loc, rewriter.getI64Type(), subGroupAddr);
-    } else {
-      subGroupAddr = adaptor.getSource();
-    }
+    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+    // But type converter will convert them to integer types.
+    Value addr = adaptor.getSource();
+    // ui32 or i32 are passed as i32 so they need to be casted to i64.
+    if (addr.getType() != rewriter.getI64Type())
+      addr = arith::IndexCastUIOp::create(
+          rewriter, loc, rewriter.getI64Type(), addr);
     auto laneAddr =
-        addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op));
+        addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
     rewriter.replaceOp(op, laneAddr);
     return success();
   }
@@ -485,11 +474,18 @@ class UpdateOffsetToXeVMPattern
   matchAndRewrite(xegpu::UpdateOffsetOp op,
                   xegpu::UpdateOffsetOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected element type bit width to be multiple of 8.");
+    }
     auto loc = op.getLoc();
-    Value newOffsetForLane =
+    // scatter descriptor is provided as scalar i64 by type converter.
+    // offsets are provided as scalar i64 by type converter.
+    Value newOffset =
         addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
                   getElemByteSize(op));
-    rewriter.replaceOp(op, newOffsetForLane);
+    rewriter.replaceOp(op, newOffset);
     return success();
   }
 };
@@ -505,19 +501,38 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
-        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+            ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    if (tdescTy)
+        ptrTypeLLVM = LLVM::LLVMPointerType::get(
+            ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64;
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       basePtrI64 = adaptor.getSource();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
     } else {
       basePtrI64 = adaptor.getDest();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
     }
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                                basePtrI64);
+    }
+    basePtrI64.dump();
     Value offsets = adaptor.getOffsets();
+    offsets.dump();
     Value mask = adaptor.getMask();
+    mask.dump();
     if (offsets) {
-      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
-      if (offsetsVecTy) {
+      if (dyn_cast<VectorType>(offsets.getType())){
         // Offset needs be scalar.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
@@ -526,8 +541,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
             addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
       }
     }
+    basePtrI64.dump();
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    basePtrLLVM.dump();
     VectorType srcOrDstVecTy = op.getValueType();
     VectorType srcOrDstFlatVecTy = VectorType::get(
         srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
@@ -597,6 +614,10 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64 = adaptor.getSource();
     Value offsets = adaptor.getOffsets();
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                                basePtrI64);
+    }
     if (offsets) {
       VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
       if (offsetsVecTy) {
@@ -836,6 +857,26 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
+    typeConverter.addConversion([&](MemRefType type) -> Type {
+      // Convert MemRefType to i64 type.
+      return IntegerType::get(&getContext(), 64);
+    });
+
+    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+
+        Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
+          builder, loc, input);
+        return arith::IndexCastUIOp::create(builder, loc, type,
+                                            addr).getResult();
+      }
+      return {};
+    };
 
     auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                       ValueRange inputs,
@@ -847,7 +888,22 @@ struct ConvertXeGPUToXeVMPass
         Value cast =
             index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                 .getResult();
-        return arith::IndexCastOp::create(builder, loc, type, cast).getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+      }
+      return {};
+    };
+
+    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (input.getType() == builder.getIntegerType(32, false)) {
+        Value cast =
+            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+                .getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
       }
       return {};
     };
@@ -864,15 +920,19 @@ struct ConvertXeGPUToXeVMPass
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();
           if (vecTy.getElementType() == builder.getIndexType())
-            cast = arith::IndexCastOp::create(builder, loc, type, cast)
+            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
         }
       }
       return {};
     };
+    typeConverter.addSourceMaterialization(memrefMaterializationCast);
     typeConverter.addSourceMaterialization(ui64MaterializationCast);
+    typeConverter.addSourceMaterialization(ui32MaterializationCast);
     typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+    typeConverter.addTargetMaterialization(memrefMaterializationCast);
+    typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
     typeConverter.addTargetMaterialization(vector1DMaterializationCast);
     ConversionTarget target(getContext());
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4fba920f023c4..7f5e3527a1594 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -6,8 +6,8 @@ gpu.module @create_nd_tdesc {
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
   gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
        %stride1: index, %stride2: index) kernel {
-         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
-         // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+        // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+        // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index f6d023307313a..825a4d6368863 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -5,10 +5,10 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:.*]]: ui64
 gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
   // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
new file mode 100644
index 0000000000000..a7ae4d9b7e4d2
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @materializecast {
+  // CHECK-LABEL: gpu.func @materialize_memref
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+  gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_ui64
+  // CHECK-SAME: %[[ARG0:.*]]: ui64
+  gpu.func @materialize_ui64(%src: ui64) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_ui32
+  // CHECK-SAME: %[[ARG0:.*]]: ui32
+  gpu.func @materialize_ui32(%src: ui32) kernel {
+    %offset = arith.constant dense<0> : vector<1xindex>
+    //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+    //    -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_single_index_vector
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+  gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_single_elem_vector
+  // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1>
+  gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %mask = arith.constant dense<1>: vector<1xi1>
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+      : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
+    gpu.return
+  }
+}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
index e9d7fd4cf40a6..6e59414c62582 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir
@@ -4,12 +4,12 @@ gpu.module @update_offset {
   // CHECK-LABEL: gpu.func @update_offset
   // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
   gpu.func @update_offset(%src: memref<128xf32>) kernel {
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     %offset = arith.constant dense<0> : vector<1xindex>
     // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
-    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
     // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
     // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64

>From aa8f765920e9d90715ed1b6be85d6622f06d7a10 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 18:32:56 +0000
Subject: [PATCH 06/10] Adjust to latest XeGPU dialect update.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 198 ++++++++++++------
 .../XeGPUToXeVM/loadstoreprefetch.mlir        | 142 ++-----------
 .../XeGPUToXeVM/materializecast.mlir          |  39 +++-
 3 files changed, 185 insertions(+), 194 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 4ff5321c1d9d2..6cfa8ac1f8fce 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -446,9 +446,10 @@ class CreateDescToXeVMPattern
   matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
-    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected element type bit width to be multiple of 8.");
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
     // offsets are provided as scalar i64 by type converter.
@@ -458,10 +459,8 @@ class CreateDescToXeVMPattern
     Value addr = adaptor.getSource();
     // ui32 or i32 are passed as i32 so they need to be casted to i64.
     if (addr.getType() != rewriter.getI64Type())
-      addr = arith::IndexCastUIOp::create(
-          rewriter, loc, rewriter.getI64Type(), addr);
-    auto laneAddr =
-        addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
+      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
     rewriter.replaceOp(op, laneAddr);
     return success();
   }
@@ -475,16 +474,16 @@ class UpdateOffsetToXeVMPattern
                   xegpu::UpdateOffsetOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto eTy = op.getTensorDescType().getElementType();
-    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
-      return rewriter.notifyMatchFailure(op,
-                                         "Expected element type bit width to be multiple of 8.");
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
     // scatter descriptor is provided as scalar i64 by type converter.
     // offsets are provided as scalar i64 by type converter.
-    Value newOffset =
-        addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
-                  getElemByteSize(op));
+    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+                                adaptor.getOffsets(), eBw / 8);
     rewriter.replaceOp(op, newOffset);
     return success();
   }
@@ -501,12 +500,35 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
-            ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
-    if (tdescTy)
-        ptrTypeLLVM = LLVM::LLVMPointerType::get(
-            ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64;
+    // Load result or Store valye Type can be vector or scalar.
+    Type valOrResTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      valOrResTy = op.getResult().getType();
+    } else {
+      valOrResTy = adaptor.getValue().getType();
+    }
+    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+    bool hasScalarVal = !valOrResVecTy;
+    int64_t elemBitWidth =
+        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    }
+    int64_t elemByteSize = elemBitWidth / 8;
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy) {
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    }
+    // Base pointer can come from source (load) or dest (store).
+    // If they are memrefs, we use their memory space.
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       basePtrI64 = adaptor.getSource();
       if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
@@ -522,76 +544,79 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
           ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
       }
     }
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
     if (basePtrI64.getType() != rewriter.getI64Type()) {
-      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
-                                                basePtrI64);
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
     }
-    basePtrI64.dump();
     Value offsets = adaptor.getOffsets();
-    offsets.dump();
     Value mask = adaptor.getMask();
-    mask.dump();
     if (offsets) {
-      if (dyn_cast<VectorType>(offsets.getType())){
-        // Offset needs be scalar.
+      if (dyn_cast<VectorType>(offsets.getType())) {
+        // Offset needs be scalar. Single element vector is converted to scalar
+        // by type converter.
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
       } else {
+        // If offsets are provided, we add them to the base pointer.
+        // Offsets are in number of elements, we need to multiply by
+        // element byte size.
         basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
       }
     }
-    basePtrI64.dump();
+    // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
-    basePtrLLVM.dump();
-    VectorType srcOrDstVecTy = op.getValueType();
-    VectorType srcOrDstFlatVecTy = VectorType::get(
-        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
+
     Value maskForLane;
     VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
     if (maskVecTy) {
+      // Mask needs be scalar. Single element vector is converted to scalar by
+      // type converter.
       return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
-    } else
+    } else {
       maskForLane = mask;
+    }
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
-      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {srcOrDstVecTy},
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                          maskForLane, true, true);
+      // If mask is true,- then clause - load from memory and yield.
       rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      if (!hasScalarVal)
+        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+                                     valOrResVecTy.getElementType());
       Value loaded =
-          LLVM::LoadOp::create(rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
+          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+      // Set cache control attribute on the load operation.
       loaded.getDefiningOp()->setAttr(
           "cache_control", xevm::LoadCacheControlAttr::get(
                                ctxt, translateLoadXeGPUCacheHint(
                                          op.getL1Hint(), op.getL3Hint())));
-      if (srcOrDstVecTy != srcOrDstFlatVecTy) {
-        loaded =
-            vector::ShapeCastOp::create(rewriter, loc, srcOrDstVecTy, loaded);
-      }
       scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
       rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      // If mask is false, we yield a vector of zeros.
-      auto eTy = srcOrDstVecTy.getElementType();
-      loaded = arith::ConstantOp::create(
-          rewriter, loc,
-          eTy.isFloat()
-              ? DenseElementsAttr::get(srcOrDstVecTy, FloatAttr::get(eTy, 0.0))
-              : DenseElementsAttr::get(srcOrDstVecTy,
-                                       IntegerAttr::get(eTy, 0)));
+      // If mask is false - else clause -yield a vector of zeros.
+      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+      TypedAttr eVal;
+      if (eTy.isFloat())
+        eVal = FloatAttr::get(eTy, 0.0);
+      else
+        eVal = IntegerAttr::get(eTy, 0);
+      if (hasScalarVal)
+        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+      else
+        loaded = arith::ConstantOp::create(
+            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
       scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
       rewriter.replaceOp(op, ifOp.getResult(0));
     } else {
+      // if mask is true, perform the store.
       scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
       auto body = ifOp.getBody();
       rewriter.setInsertionPointToStart(body);
-      VectorType valTy = op.getValue().getType();
-      Value srcFlatVec = op.getValue();
-      if (valTy != srcOrDstFlatVecTy) {
-        srcFlatVec = vector::ShapeCastOp::create(rewriter, loc,
-                                                 srcOrDstFlatVecTy, srcFlatVec);
-      }
       auto storeOp =
-          LLVM::StoreOp::create(rewriter, loc, srcFlatVec, basePtrLLVM);
+          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+      // Set cache control attribute on the store operation.
       storeOp.getOperation()->setAttr(
           "cache_control", xevm::StoreCacheControlAttr::get(
                                ctxt, translateStoreXeGPUCacheHint(
@@ -610,14 +635,13 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
-        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64 = adaptor.getSource();
-    Value offsets = adaptor.getOffsets();
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
     if (basePtrI64.getType() != rewriter.getI64Type()) {
-      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
-                                                basePtrI64);
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
     }
+    Value offsets = adaptor.getOffsets();
     if (offsets) {
       VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
       if (offsetsVecTy) {
@@ -625,12 +649,50 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
         return rewriter.notifyMatchFailure(op,
                                            "Expected offsets to be a scalar.");
       } else {
+        int64_t elemBitWidth{0};
+        int64_t elemByteSize;
+        // Element byte size can come from three sources:
+        if (tdescTy) {
+          // If tensor descriptor is available, we use its element type to
+          // determine element byte size.
+          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
+        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
+          // If memref is available, we use its element type to
+          // determine element byte size.
+          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
+        } else {
+          // Otherwise, we use the provided offset byte alignment.
+          elemByteSize = *op.getOffsetAlignByte();
+        }
+        if (elemBitWidth != 0) {
+          if (elemBitWidth % 8 != 0) {
+            return rewriter.notifyMatchFailure(
+                op, "Expected element type bit width to be multiple of 8.");
+          }
+          elemByteSize = elemBitWidth / 8;
+        }
         basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
+            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
       }
     }
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy) {
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    }
+    // If source is a memref, we use its memory space.
+    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+      auto addrSpace = memRefTy.getMemorySpaceAsInt();
+      if (addrSpace != 0)
+        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+    }
+    // Convert base pointer (i64) to LLVM pointer type.
     Value ptrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    // Create the prefetch op with cache control attribute.
     xevm::PrefetchOp::create(
         rewriter, loc, ptrLLVM,
         xevm::LoadCacheControlAttr::get(
@@ -863,17 +925,17 @@ struct ConvertXeGPUToXeVMPass
     });
 
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
-                                      ValueRange inputs,
-                                      Location loc) -> Value {
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
       if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
 
-        Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
-          builder, loc, input);
-        return arith::IndexCastUIOp::create(builder, loc, type,
-                                            addr).getResult();
+        Value addr =
+            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
+        return arith::IndexCastUIOp::create(builder, loc, type, addr)
+            .getResult();
       }
       return {};
     };
@@ -888,7 +950,8 @@ struct ConvertXeGPUToXeVMPass
         Value cast =
             index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                 .getResult();
-        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast)
+            .getResult();
       }
       return {};
     };
@@ -903,7 +966,8 @@ struct ConvertXeGPUToXeVMPass
         Value cast =
             index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                 .getResult();
-        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast)
+            .getResult();
       }
       return {};
     };
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
index 825a4d6368863..0f67dc290689b 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir
@@ -37,15 +37,15 @@ gpu.module @test {
 // CHECK-LABEL: @load_gather_memref_src_constant_offset
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
 gpu.func @load_gather_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
   // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -73,12 +73,12 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>
 gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: vector<1xindex>) {
   // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
   // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -99,51 +99,15 @@ gpu.func @load_gather_memref_src_value_offset(%src: memref<256xf16>, %offset: ve
 }
 // -----
 
-gpu.module @test {
-// CHECK-LABEL: @load_gather_memref_src_load_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf16>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @load_gather_memref_src_load_offset(%src: memref<256xf16>, %offset1: vector<1xindex>, %offset2: vector<1xindex>) {
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf16> -> index
-  // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
-  // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C2_I64]] : i64
-  // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
-  %2 = xegpu.create_tdesc %src, %offset1 : memref<256xf16>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
-  // CHECK: %[[C2_I64_0:.*]] = arith.constant 2 : i64
-  // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C2_I64_0]] : i64
-  // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
-  // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
-  // CHECK: %[[VAR11:.*]] = scf.if %[[VAR4]] -> (vector<8xf16>) {
-  // CHECK:      %[[VAR12:.*]] = llvm.load %[[VAR10]] {cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}
-  // CHECK-SAME:     : !llvm.ptr<1> -> vector<8xf16>
-  // CHECK:      scf.yield %[[VAR12]] : vector<8xf16>
-  // CHECK:    } else {
-  // CHECK:      %[[CST_0:.*]] = arith.constant dense<0.000000e+00> : vector<8xf16>
-  // CHECK:      scf.yield %[[CST_0]] : vector<8xf16>
-  %3 = xegpu.load %2[%offset2], %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x8xf16, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<1xindex>, vector<1xi1> -> vector<8xf16>
-  gpu.return
-}
-}
-// -----
-
 gpu.module @test {
 // CHECK-LABEL: @store_scatter_ui64_src_constant_offset
 // CHECK-SAME: %[[ARG0:.*]]: ui64
 gpu.func @store_scatter_ui64_src_constant_offset(%src: ui64) {
   // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
@@ -170,17 +134,17 @@ gpu.module @test {
 // CHECK-LABEL: @store_scatter_memref_src_constant_offset
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
 gpu.func @store_scatter_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
   // CHECK: %[[CST_1:.*]] = arith.constant dense<2.900390e+00> : vector<2xf16>
   %2 = arith.constant dense<2.9>: vector<2xf16>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C2_I64:.*]] = arith.constant 2 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C2_I64]] : i64
   // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -202,14 +166,15 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
 gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
   // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
   %1 = arith.constant dense<1>: vector<1xi1>
   // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
+  // CHECK: %[[VAR7:.*]] = vector.extract %[[CST_0]][0] : f32 from vector<1xf32>
   %2 = arith.constant dense<2.9>: vector<1xf32>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
   // CHECK: %[[VAR5:.*]] = arith.addi %[[VAR3]], %[[VAR4]] : i64
@@ -217,8 +182,8 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset:
       -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
   // CHECK: %[[VAR6:.*]] = llvm.inttoptr %[[VAR5]] : i64 to !llvm.ptr<1>
   // CHECK: scf.if %[[VAR2]] {
-  // CHECK:      llvm.store %[[CST_0]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
-  // CHECK-SAME:     : vector<1xf32>, !llvm.ptr<1>
+  // CHECK:      llvm.store %[[VAR7]], %[[VAR6]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
+  // CHECK-SAME:     : f32, !llvm.ptr<1>
   xegpu.store %2, %3, %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
       : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xi1>
   gpu.return
@@ -226,49 +191,15 @@ gpu.func @store_scatter_memref_src_value_offset(%src: memref<256xf32>, %offset:
 }
 // -----
 
-gpu.module @test {
-// CHECK-LABEL: @store_scatter_memref_src_store_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @store_scatter_memref_src_store_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
-  // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
-  // CHECK: %[[VAR4:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
-  %1 = arith.constant dense<1>: vector<1xi1>
-  // CHECK: %[[CST_0:.*]] = arith.constant dense<2.900000e+00> : vector<1xf32>
-  %2 = arith.constant dense<2.9>: vector<1xf32>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR5:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR6:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR7:.*]] = arith.addi %[[VAR5]], %[[VAR6]] : i64
-  %3 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
-  // CHECK: %[[C4_I64_1:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR8:.*]] = arith.muli %[[VAR1]], %[[C4_I64_1]] : i64
-  // CHECK: %[[VAR9:.*]] = arith.addi %[[VAR7]], %[[VAR8]] : i64
-  // CHECK: %[[VAR10:.*]] = llvm.inttoptr %[[VAR9]] : i64 to !llvm.ptr<1>
-  // CHECK: scf.if %[[VAR4]] {
-  // CHECK:      llvm.store %[[CST_0]], %[[VAR10]] {cache_control = #xevm.store_cache_control<L1wb_L2uc_L3uc>}
-  // CHECK-SAME:     : vector<1xf32>, !llvm.ptr<1>
-  xegpu.store %2, %3[%offset2], %1 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : vector<1xf32>, !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>, vector<1xindex>, vector<1xi1>
-  gpu.return
-}
-}
-// -----
-
 gpu.module @test {
 // CHECK-LABEL: @prefetch_ui64_src_constant_offset
 // CHECK-SAME: %[[ARG0:.*]]: ui64
 gpu.func @prefetch_ui64_src_constant_offset(%src: ui64) {
   // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+  // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR4:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
@@ -288,12 +219,12 @@ gpu.module @test {
 // CHECK-LABEL: @prefetch_memref_src_constant_offset
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>
 gpu.func @prefetch_memref_src_constant_offset(%src: memref<256xf32>) {
+  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
+  // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
   // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   %0 = arith.constant dense<0> : vector<1xindex>
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
   // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
   // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
@@ -313,7 +244,7 @@ gpu.module @test {
 // CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>
 gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vector<1xindex>) {
   // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+  // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
   // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
   // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
   // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
@@ -328,30 +259,3 @@ gpu.func @prefetch_memref_src_value_offset(%src: memref<256xf32>, %offset: vecto
   gpu.return
 }
 }
-// -----
-
-gpu.module @test {
-// CHECK-LABEL: @prefetch_memref_src_prefetch_offset
-// CHECK-SAME: %[[ARG0:.*]]: memref<256xf32>, %[[ARG1:.*]]: vector<1xindex>, %[[ARG2:.*]]: vector<1xindex>
-gpu.func @prefetch_memref_src_prefetch_offset(%src: memref<256xf32>, %offset: vector<1xindex>, %offset2: vector<1xindex>) {
-  // CHECK: %[[VAR0:.*]] = vector.extract %[[ARG2]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
-  // CHECK: %[[VAR2:.*]] = vector.extract %[[ARG1]][0] : index from vector<1xindex>
-  // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
-  // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<256xf32> -> index
-  // CHECK: %[[VAR4:.*]] = arith.index_castui %[[INTPTR]] : index to i64
-  // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR5:.*]] = arith.muli %[[VAR3]], %[[C4_I64]] : i64
-  // CHECK: %[[VAR6:.*]] = arith.addi %[[VAR4]], %[[VAR5]] : i64
-  %1 = xegpu.create_tdesc %src, %offset : memref<256xf32>, vector<1xindex>
-      -> !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  // CHECK: %[[C4_I64_0:.*]] = arith.constant 4 : i64
-  // CHECK: %[[VAR7:.*]] = arith.muli %[[VAR1]], %[[C4_I64_0]] : i64
-  // CHECK: %[[VAR8:.*]] = arith.addi %[[VAR6]], %[[VAR7]] : i64
-  // CHECK: %[[VAR9:.*]] = llvm.inttoptr %[[VAR8]] : i64 to !llvm.ptr<1>
-  // CHECK: xevm.prefetch %[[VAR9]] <{cache_control = #xevm.load_cache_control<L1c_L2uc_L3uc>}> : (!llvm.ptr<1>)
-  xegpu.prefetch %1[%offset2] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : !xegpu.tensor_desc<1x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<1xindex>
-  gpu.return
-}
-}
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index a7ae4d9b7e4d2..8db0843de4cc1 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -1,45 +1,68 @@
-// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+// RUN: mlir-opt -convert-xegpu-to-xevm --split-input-file %s | FileCheck %s
 
 gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_memref
   // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
   gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
-    // CHECK: XXX
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+    // CHECK: %[[CASTED:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     %offset = arith.constant dense<0> : vector<1xindex>
     %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
         -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
     gpu.return
   }
+}
+
+// -----
+gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_ui64
   // CHECK-SAME: %[[ARG0:.*]]: ui64
   gpu.func @materialize_ui64(%src: ui64) kernel {
-    // CHECK: XXX
+    // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
     %offset = arith.constant dense<0> : vector<1xindex>
     %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
         -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
     gpu.return
   }
+}
+
+// -----
+gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_ui32
   // CHECK-SAME: %[[ARG0:.*]]: ui32
   gpu.func @materialize_ui32(%src: ui32) kernel {
+    // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui32 to index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i32
     %offset = arith.constant dense<0> : vector<1xindex>
-    //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
-    //    -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
     gpu.return
   }
+}
+
+// -----
+gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_single_index_vector
   // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
   gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
-    // CHECK: XXX
+    // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+    // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[VAR1]] : index to i64
     %offset = arith.constant dense<0> : vector<1xindex>
     %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
         -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
     gpu.return
   }
+}
+
+// -----
+gpu.module @materializecast {
   // CHECK-LABEL: gpu.func @materialize_single_elem_vector
-  // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1>
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
   gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
-    // CHECK: XXX
+    // CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<1xi1>
+    // CHECK: %[[VAR1:.*]] = vector.extract %[[CST]][0] : i1 from vector<1xi1>
     %mask = arith.constant dense<1>: vector<1xi1>
     %offset = arith.constant dense<0> : vector<1xindex>
     %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>

>From 434bc6b6b0b2f42879604f90bed54ec363f2f57a Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 21:11:26 +0000
Subject: [PATCH 07/10] Temp save.

---
 .../mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h |   8 +-
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 119 ++++++++----------
 .../XeGPUToXeVM/create_nd_tdesc.mlir          |  12 +-
 3 files changed, 62 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
index fb23d24b0161b..ddaaae82e03be 100644
--- a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -5,8 +5,8 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
-#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
-#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
 
 #include <memory>
 
@@ -20,8 +20,8 @@ class Pass;
 #include "mlir/Conversion/Passes.h.inc"
 
 void populateXeGPUToXeVMConversionPatterns(
-    mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns);
 
 } // namespace mlir
 
-#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVM_H_
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 6cfa8ac1f8fce..19324f748bded 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -57,7 +57,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   llvm_unreachable("Unknown XeGPU memory space.");
 }
 
-VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+static VectorType encodeVectorTypeTo(VectorType currentVecType,
+                                     Type toElemType) {
   auto elemType = currentVecType.getElementType();
   auto currentBitWidth = elemType.getIntOrFloatBitWidth();
   auto newBitWidth = toElemType.getIntOrFloatBitWidth();
@@ -66,13 +67,11 @@ VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
   return VectorType::get(size, toElemType);
 }
 
-xevm::LoadCacheControl
+static xevm::LoadCacheControl
 translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
-  auto L1hintVal =
-      L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
-  auto L3hintVal =
-      L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
   switch (L1hintVal) {
   case xegpu::CachePolicy::CACHED:
     if (L3hintVal == xegpu::CachePolicy::CACHED)
@@ -102,13 +101,11 @@ translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
   }
 }
 
-xevm::StoreCacheControl
+static xevm::StoreCacheControl
 translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                              std::optional<xegpu::CachePolicy> L3hint) {
-  auto L1hintVal =
-      L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
-  auto L3hintVal =
-      L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
+  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
   switch (L1hintVal) {
   case xegpu::CachePolicy::UNCACHED:
     if (L3hintVal == xegpu::CachePolicy::UNCACHED)
@@ -152,10 +149,14 @@ class CreateNdDescToXeVMPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto source = op.getSource();
+    // Op is lowered to a code sequence that populates payload.
+    // The payload is an 8xi32 vector.
     Type payloadElemTy = rewriter.getI32Type();
     Type i64Ty = rewriter.getI64Type();
     VectorType payloadTy = VectorType::get(8, payloadElemTy);
+    // 4xi64 view is used for inserting the base pointer.
     VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+    // Initialize payload to zero.
     Value payload = arith::ConstantOp::create(
         rewriter, loc,
         DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
@@ -166,73 +167,56 @@ class CreateNdDescToXeVMPattern
     Value offsetW;
     Value offsetH;
 
-    bool sourceIsMemref = false;
+    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
+    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
+    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
+    // Descriptor shape is expected to be 2D.
+    int64_t rank = mixedSizes.size();
+    if (rank != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
     auto sourceTy = source.getType();
-    int64_t rank;
-    if (isa<MemRefType>(sourceTy)) {
-      sourceIsMemref = true;
+    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
+    // If source is a memref, we need to extract the aligned pointer as index.
+    // Otherwise, the pointer is passed as i32 or i64 by the type converter.
+    if (sourceMemrefTy) {
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
-      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
       if (!sourceMemrefTy.hasStaticShape()) {
         op.emitError() << "Expected static memref shape.";
         return failure();
       }
-      rank = sourceMemrefTy.getRank();
-      if (rank != 2) {
-        op.emitError() << "Expected a 2D memref.";
-        return failure();
-      }
-    } else if (sourceTy == rewriter.getIntegerType(64, false)) {
-      rank = op.getMixedSizes().size();
     } else {
-      op.emitError() << "Expected source to be a 2D memref or ui64.";
-      return failure();
+      baseAddr = adaptor.getSource();
     }
-    auto createOffset = [&](unsigned idx) -> Value {
-      Value val;
-      OpFoldResult ofr = op.getMixedOffsets()[idx];
-      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
-        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
-        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
-      } else {
-        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
-        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
-      }
+    // Creates an i32 payload value (offset or shape) from an OpFoldResult.
+    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
+                            unsigned idx) -> Value {
+      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
+      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
       return val;
     };
-    auto offsets = op.getMixedOffsets();
-    if (offsets.size() == 2) {
-      offsetW = createOffset(rank - 1);
-      offsetH = createOffset(rank - 2);
-    } else {
+    // Offsets can be either 2D or not provided (0 is used).
+    if (mixedOffsets.size() == 2) {
+      offsetW = createOffset(mixedOffsets, rank - 1);
+      offsetH = createOffset(mixedOffsets, rank - 2);
+    } else if (mixedOffsets.size() == 0) {
       offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
       offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+    } else {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected 2D offsets or no offsets.");
     }
-    auto createShape = [&](unsigned idx) -> Value {
-      Value val;
-      OpFoldResult ofr = op.getMixedSizes()[idx];
-      if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
-        val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
-        val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
-      } else {
-        int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
-        val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
-      }
-      return val;
-    };
-    if (sourceIsMemref) {
-      auto sourceMemrefTy = cast<MemRefType>(sourceTy);
-      baseShapeW = arith::ConstantIntOp::create(
-          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
-      baseShapeH = arith::ConstantIntOp::create(
-          rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+    // Get shape values from op fold results.
+    baseShapeW = createOffset(mixedSizes, rank - 1);
+    baseShapeH = createOffset(mixedSizes, rank - 2);
+    if (sourceMemrefTy) {
+      // cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
-    } else {
-      baseShapeW = createShape(rank - 1);
-      baseShapeH = createShape(rank - 2);
-      baseAddr = adaptor.getSource();
+    } else if (baseAddr.getType() != i64Ty) {
+      // The pointer may be i32; zero-extend it to i64.
+      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
     }
+    // Populate payload.
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
     payLoadAsI64 =
@@ -429,9 +413,9 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
-auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
-                    Value baseAddr, Value offset,
-                    int64_t elemByteSize) -> Value {
+static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
+                           Value baseAddr, Value offset,
+                           int64_t elemByteSize) -> Value {
   Value byteSize = arith::ConstantIntOp::create(
       rewriter, loc, rewriter.getI64Type(), elemByteSize);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
@@ -701,6 +685,7 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
     return success();
   }
 };
+
 class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -1007,7 +992,7 @@ struct ConvertXeGPUToXeVMPass
     target.addIllegalDialect<xegpu::XeGPUDialect>();
 
     RewritePatternSet patterns(&getContext());
-    populateXeGPUToXeVMConversionPatterns(patterns, typeConverter);
+    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
     scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                          patterns, target);
     if (failed(applyPartialConversion(getOperation(), target,
@@ -1021,7 +1006,7 @@ struct ConvertXeGPUToXeVMPass
 // Pattern Population
 //===----------------------------------------------------------------------===//
 void mlir::populateXeGPUToXeVMConversionPatterns(
-    RewritePatternSet &patterns, LLVMTypeConverter &typeConverter) {
+    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
   patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
                LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 7f5e3527a1594..ba7ece8ccbebe 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -11,10 +11,8 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
         // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32
-        // CHECK: %[[VAR2:.*]] = arith.index_cast %[[ARG3]] : index to i64
-        // CHECK: %[[VAR3:.*]] = arith.trunci %[[VAR2]] : i64 to i32
-        // CHECK: %[[VAR4:.*]] = arith.index_cast %[[ARG2]] : index to i64
-        // CHECK: %[[VAR5:.*]] = arith.trunci %[[VAR4]] : i64 to i32
+        // CHECK: %[[VAR3:.*]] = arith.index_cast %[[ARG3]] : index to i32
+        // CHECK: %[[VAR5:.*]] = arith.index_cast %[[ARG2]] : index to i32
         // CHECK: %[[VAR6:.*]] = vector.bitcast %[[CST]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR7:.*]] = vector.insert %[[VAR1]], %[[VAR6]] [0] : i64 into vector<4xi64>
         // CHECK: %[[VAR8:.*]] = vector.bitcast %[[VAR7]] : vector<4xi64> to vector<8xi32>
@@ -32,8 +30,10 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
         // CHECK: %[[C0_I32_2:.*]] = arith.constant 0 : i32
         // CHECK: %[[C0_I32_3:.*]] = arith.constant 0 : i32
-        // CHECK: %[[C16_I32:.*]] = arith.constant 16 : i32
-        // CHECK: %[[C8_I32:.*]] = arith.constant 8 : i32
+        // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
+        // CHECK: %[[C16_I32:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
+        // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
+        // CHECK: %[[C8_I32:.*]] = arith.trunci %[[C8_I64]] : i64 to i32
         // CHECK: %[[VAR13:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[VAR13]], %[[VAR14]] [0] : i64 into vector<4xi64>

>From d32a4445a6e8219618c887c912772c6af9f1350a Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 22:11:43 +0000
Subject: [PATCH 08/10] Update update_nd_tdesc.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 39 +++++++++----------
 1 file changed, 19 insertions(+), 20 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 19324f748bded..12af0af70177b 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -248,27 +248,26 @@ class UpdateNdOffsetToXeVMPattern
                   xegpu::UpdateNdOffsetOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
-    auto offsets = op.getOffsets();
+    auto mixedOffsets = op.getMixedOffsets();
+    if (mixedOffsets.size() != 2)
+      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto tdesc = adaptor.getTensorDesc();
-    for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
-      auto offset = offsets[offsetDim];
-      if (auto cst =
-              dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
-        if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
-            attr && !attr.getInt())
-          continue;
-      const int offsetPos =
-          static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
-                                     : NdDescI32Layout::TensorOffsetH);
-      auto oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
-      offset = arith::IndexCastUIOp::create(rewriter, loc,
-                                            rewriter.getI32Type(), offset);
-      auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      tdesc =
-          vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
-    }
-    rewriter.replaceOp(op, tdesc);
+    // utility for updating payload offset values from op fold result.
+    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
+      Value offset =
+          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
+      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                               rewriter.getI32Type(), offset);
+      Value oldOffset =
+          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+                                      payloadPos);
+    };
+    auto val =
+        updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
+    val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    rewriter.replaceOp(op, val);
     return success();
   }
 };

>From 6a1c12625fa9624c2b675976fb96a4ef63a8e2d1 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Mon, 25 Aug 2025 22:29:06 +0000
Subject: [PATCH 09/10] Temp save.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 160 ++++++++++--------
 .../Conversion/XeGPUToXeVM/loadstore_nd.mlir  |  12 +-
 .../XeGPUToXeVM/materializecast.mlir          |   2 +-
 .../Conversion/XeGPUToXeVM/prefetch_nd.mlir   |   6 +-
 4 files changed, 99 insertions(+), 81 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 12af0af70177b..963ab29695b1f 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -39,12 +39,13 @@ using namespace mlir;
 
 namespace {
 
-enum class NdDescI32Layout : uint32_t {
-  BasePtr = 0,
-  BaseShapeW = 2,
-  BaseShapeH = 3,
-  TensorOffsetW = 4,
-  TensorOffsetH = 5
+// Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
+enum class NdTdescOffset : uint32_t {
+  BasePtr = 0,       // Base pointer (i64)
+  BaseShapeW = 2,    // Base shape width (i32)
+  BaseShapeH = 3,    // Base shape height (i32)
+  TensorOffsetW = 4, // Tensor offset W (i32)
+  TensorOffsetH = 5  // Tensor offset H (i32)
 };
 
 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
@@ -57,6 +58,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
   llvm_unreachable("Unknown XeGPU memory space.");
 }
 
+// Get same bitwidth flat vector type of new element type.
 static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                      Type toElemType) {
   auto elemType = currentVecType.getElementType();
@@ -221,20 +223,20 @@ class CreateNdDescToXeVMPattern
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
     payLoadAsI64 =
         vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
-                                 static_cast<int>(NdDescI32Layout::BasePtr));
+                                 static_cast<int>(NdTdescOffset::BasePtr));
     payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
-                                 static_cast<int>(NdDescI32Layout::BaseShapeW));
+                                 static_cast<int>(NdTdescOffset::BaseShapeW));
     payload =
         vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
-                                 static_cast<int>(NdDescI32Layout::BaseShapeH));
+                                 static_cast<int>(NdTdescOffset::BaseShapeH));
     payload = vector::InsertOp::create(
         rewriter, loc, offsetW, payload,
-        static_cast<int>(NdDescI32Layout::TensorOffsetW));
+        static_cast<int>(NdTdescOffset::TensorOffsetW));
     payload = vector::InsertOp::create(
         rewriter, loc, offsetH, payload,
-        static_cast<int>(NdDescI32Layout::TensorOffsetH));
+        static_cast<int>(NdTdescOffset::TensorOffsetH));
     rewriter.replaceOp(op, payload);
     return success();
   }
@@ -249,6 +251,7 @@ class UpdateNdOffsetToXeVMPattern
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     auto mixedOffsets = op.getMixedOffsets();
+    // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto tdesc = adaptor.getTensorDesc();
@@ -264,9 +267,9 @@ class UpdateNdOffsetToXeVMPattern
       return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
                                       payloadPos);
     };
-    auto val =
-        updateOffset(0, static_cast<int>(NdDescI32Layout::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdDescI32Layout::TensorOffsetW));
+    // Update offsets in the payload.
+    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
     rewriter.replaceOp(op, val);
     return success();
   }
@@ -293,62 +296,46 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
     Value payLoadAsI64 =
         vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
-    Value basePtr =
-        vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
-                                  static_cast<int>(NdDescI32Layout::BasePtr));
+    Value basePtr = vector::ExtractOp::create(
+        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
     Value baseShapeW = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
     Value baseShapeH = vector::ExtractOp::create(
-        rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
-    // Offsets can come from three sources:
-    // 1. Constant offsets, which are provided by the op.
-    // 2. Offsets as operands, which are provided by the op.
-    // 3. Offsets extracted from the tensor descriptor.
+        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
+    // Offsets provided in two ways:
+    // 1. Offsets are extracted from the tensor descriptor.
+    // 2. (Mixed) offsets which are provided by the op.
     Value offsetW;
     Value offsetH;
-    auto cOffsets = op.getConstOffsets();
-    auto offsets = op.getOffsets();
-    if (cOffsets) {
-      offsetW = arith::ConstantIntOp::create(
-          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
-      offsetH = arith::ConstantIntOp::create(
-          rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
-    } else if (offsets.size() != 0) {
-      // offsets are provided as operands
-      if (offsets[0].getType() != rewriter.getI32Type()) {
-        if (offsets[0].getType() != rewriter.getIndexType()) {
-          return rewriter.notifyMatchFailure(
-              op, "Expected offsets to be of type i32 or index.");
-        }
-        offsetW = arith::IndexCastUIOp::create(
-            rewriter, loc, rewriter.getI32Type(), offsets[0]);
-      } else {
-        offsetW = offsets[0];
-      }
-      if (offsets[1].getType() != rewriter.getI32Type()) {
-        if (offsets[1].getType() != rewriter.getIndexType()) {
-          return rewriter.notifyMatchFailure(
-              op, "Expected offsets to be of type i32 or index.");
-        }
-        offsetH = arith::IndexCastUIOp::create(
-            rewriter, loc, rewriter.getI32Type(), offsets[1]);
-      } else {
-        offsetH = offsets[1];
-      }
+    auto mixedOffsets = op.getMixedOffsets();
+    int64_t opOffsetsSize = mixedOffsets.size();
+    if (opOffsetsSize != 0 && opOffsetsSize != 2) {
+      return rewriter.notifyMatchFailure(op,
+                                         "Expected 2D offsets or no offsets.");
+    }
+    if (opOffsetsSize) {
+      // If mixed offsets are provided by the op convert them to i32.
+      offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
+      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetW);
+      offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
+      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
+                                                rewriter.getI32Type(), offsetH);
     } else {
       // If offsets are not available, we need to extract them from the tensor
       // descriptor.
       offsetW = vector::ExtractOp::create(
-          rewriter, loc, tdesc,
-          static_cast<int>(NdDescI32Layout::TensorOffsetW));
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
       offsetH = vector::ExtractOp::create(
-          rewriter, loc, tdesc,
-          static_cast<int>(NdDescI32Layout::TensorOffsetH));
+          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
     }
+    // Get address space from tensor descriptor memory space.
     auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+    // Compute element byte size and surface width in bytes.
     auto elemType = tdescTy.getElementType();
     auto elemBitSize = elemType.getIntOrFloatBitWidth();
     Value elemByteSize = arith::ConstantIntOp::create(
@@ -356,23 +343,27 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     Value surfaceW =
         arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
 
+    // Get tile sizes and vblocks from the tensor descriptor type.
     auto tileW = tdescTy.getDimSize(1);
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      if (!srcVecTy) {
+        return rewriter.notifyMatchFailure(
+            op, "Expected store value to be a vector type.");
+      }
       auto storeCacheControl =
           translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      VectorType srcFlatVecTy =
-          VectorType::get(srcVecTy.getNumElements(), srcVecTy.getElementType());
-      Value srcFlatVec = op.getValue();
-      srcFlatVecTy = encodeVectorTypeTo(srcFlatVecTy,
-                                        rewriter.getIntegerType(elemBitSize));
-      srcFlatVec =
-          vector::BitCastOp::create(rewriter, loc, srcFlatVecTy, srcFlatVec);
+      Value src = adaptor.getValue();
+      // Get flat vector type of integer type with matching element bit size.
+      VectorType newSrcVecTy =
+          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
+      if (srcVecTy != newSrcVecTy)
+        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
-          offsetH, elemBitSize, tileW, tileH, srcFlatVec,
+          offsetH, elemBitSize, tileW, tileH, src,
           xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
       rewriter.eraseOp(op);
     } else {
@@ -412,15 +403,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
-static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
-                           Value baseAddr, Value offset,
-                           int64_t elemByteSize) -> Value {
+static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
+                       Value baseAddr, Value offset, int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
       rewriter, loc, rewriter.getI64Type(), elemByteSize);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
-};
+}
 
 class CreateDescToXeVMPattern
     : public OpConversionPattern<xegpu::CreateDescOp> {
@@ -908,6 +898,10 @@ struct ConvertXeGPUToXeVMPass
       return IntegerType::get(&getContext(), 64);
     });
 
+    // LLVM type converter puts unrealized casts for the following cases:
+    // add materialization casts to handle them.
+
+    // Materialization to convert memref to i64
     auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                         ValueRange inputs,
                                         Location loc) -> Value {
@@ -924,6 +918,7 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
 
+    // Materialization to convert ui64 to i64
     auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                       ValueRange inputs,
                                       Location loc) -> Value {
@@ -940,6 +935,7 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
 
+    // Materialization to convert ui32 to i32
     auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                       ValueRange inputs,
                                       Location loc) -> Value {
@@ -956,9 +952,13 @@ struct ConvertXeGPUToXeVMPass
       return {};
     };
 
-    auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
-                                          ValueRange inputs,
-                                          Location loc) -> Value {
+    // Materialization to convert
+    //   - single element 1D vector to scalar
+    //   - bitcast vector of same rank
+    //   - shape vector of different rank but same element type
+    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
       if (inputs.size() != 1)
         return {};
       auto input = inputs.front();
@@ -971,6 +971,18 @@ struct ConvertXeGPUToXeVMPass
             cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
+        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
+          // If the target type is a vector of same rank,
+          //   bitcast to the target type.
+          if (targetVecTy.getRank() == vecTy.getRank())
+            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
+                .getResult();
+          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
+            // If the target type is a vector of different rank but same element
+            // type, reshape to the target type.
+            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
+                .getResult();
+          }
         }
       }
       return {};
@@ -978,11 +990,11 @@ struct ConvertXeGPUToXeVMPass
     typeConverter.addSourceMaterialization(memrefMaterializationCast);
     typeConverter.addSourceMaterialization(ui64MaterializationCast);
     typeConverter.addSourceMaterialization(ui32MaterializationCast);
-    typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
-    typeConverter.addTargetMaterialization(vector1DMaterializationCast);
+    typeConverter.addTargetMaterialization(vectorMaterializationCast);
     ConversionTarget target(getContext());
     target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                            vector::VectorDialect, arith::ArithDialect,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
index c692da632d458..4c6bbf25b4728 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_nd.mlir
@@ -20,8 +20,10 @@ gpu.module @load_store_check {
         //CHECK: %[[LD_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[LD_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
         //CHECK: %[[LD_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[LD_TILE_W:.*]] = arith.constant 0 : i32
-        //CHECK: %[[LD_TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[LD_TILE_W64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[LD_TILE_W:.*]] = arith.trunci %[[LD_TILE_W64]] : i64 to i32
+        //CHECK: %[[LD_TILE_H64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[LD_TILE_H:.*]] = arith.trunci %[[LD_TILE_H64]] : i64 to i32
         //CHECK: %[[LD_LLVMPTR:.*]] = llvm.inttoptr %[[LD_INTPTR]] : i64 to !llvm.ptr<1>
         //CHECK: %[[LD_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[LD_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[LD_BASE_W]], %[[LD_SIZEOF_F32]] : i32
@@ -54,8 +56,10 @@ gpu.module @load_store_check {
         //CHECK: %[[INTPTR:.*]] = vector.extract %[[DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[BASE_W:.*]] = vector.extract %[[DESC]][2] : i32 from vector<8xi32>
         //CHECK: %[[BASE_H:.*]] = vector.extract %[[DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[TILE_W:.*]] = arith.constant 0 : i32
-        //CHECK: %[[TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[TILE_W64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[TILE_W:.*]] = arith.trunci %[[TILE_W64]] : i64 to i32
+        //CHECK: %[[TILE_H64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[TILE_H:.*]] = arith.trunci %[[TILE_H64]] : i64 to i32
         //CHECK: %[[LLVMPTR:.*]] = llvm.inttoptr %[[INTPTR]] : i64 to !llvm.ptr<1>
         //CHECK: %[[SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[BASE_ROW_IN_BYTES:.*]] = arith.muli %[[BASE_W]], %[[SIZEOF_F32]] : i32
diff --git a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
index 8db0843de4cc1..2445c4b341657 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/materializecast.mlir
@@ -66,7 +66,7 @@ gpu.module @materializecast {
     %mask = arith.constant dense<1>: vector<1xi1>
     %offset = arith.constant dense<0> : vector<1xindex>
     %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-      : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
+      : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
     gpu.return
   }
 }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
index 8513b4f9857fb..873478aed57e3 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/prefetch_nd.mlir
@@ -20,8 +20,10 @@ gpu.module @fence_check {
         //CHECK: %[[PREF_INTPTR:.*]] = vector.extract %[[LD_DESC_I64]][0] : i64 from vector<4xi64>
         //CHECK: %[[PREF_BASE_W:.*]] = vector.extract %[[LD_DESC]][2] : i32 from vector<8xi32>
         //CHECK: %[[PREF_BASE_H:.*]] = vector.extract %[[LD_DESC]][3] : i32 from vector<8xi32>
-        //CHECK: %[[PREF_TILE_W:.*]] = arith.constant 0 : i32
-        //CHECK: %[[PREF_TILE_H:.*]] = arith.constant 0 : i32
+        //CHECK: %[[PREF_TILE_W64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[PREF_TILE_W:.*]] = arith.trunci %[[PREF_TILE_W64]] : i64 to i32
+        //CHECK: %[[PREF_TILE_H64:.*]] = arith.constant 0 : i64
+        //CHECK: %[[PREF_TILE_H:.*]] = arith.trunci %[[PREF_TILE_H64]] : i64 to i32
         //CHECK: %[[PREF_LLVMPTR:.*]] = llvm.inttoptr %[[PREF_INTPTR]] : i64 to !llvm.ptr<1>
         //CHECK: %[[PREF_SIZEOF_F32:.*]] = arith.constant 4 : i32
         //CHECK: %[[PREF_BASE_ROW_IN_BYTES:.*]] = arith.muli %[[PREF_BASE_W]], %[[PREF_SIZEOF_F32]] : i32

>From cd81113781e11f0c4c488d9398ba2fb43154fb16 Mon Sep 17 00:00:00 2001
From: "Lee, Sang Ik" <sang.ik.lee at intel.com>
Date: Tue, 26 Aug 2025 00:49:47 +0000
Subject: [PATCH 10/10] Address comments.

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 23 +++++++++----------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 963ab29695b1f..6cd50a38a21a4 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -152,7 +152,7 @@ class CreateNdDescToXeVMPattern
     auto loc = op.getLoc();
     auto source = op.getSource();
     // Op is lowered to a code sequence that populates payload.
-    // payload is a 8xi32 vector.
+    // Payload is a 8xi32 vector.
     Type payloadElemTy = rewriter.getI32Type();
     Type i64Ty = rewriter.getI64Type();
     VectorType payloadTy = VectorType::get(8, payloadElemTy);
@@ -179,7 +179,7 @@ class CreateNdDescToXeVMPattern
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
-    // pointer type is passed as i32 or i64 by type converter.
+    // Pointer type is passed as i32 or i64 by type converter.
     if (sourceMemrefTy) {
       baseAddr =
           memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
@@ -190,7 +190,7 @@ class CreateNdDescToXeVMPattern
     } else {
       baseAddr = adaptor.getSource();
     }
-    // utility for creating offset values from op fold result.
+    // Utility for creating offset values from op fold result.
     auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                             unsigned idx) -> Value {
       Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
@@ -212,10 +212,10 @@ class CreateNdDescToXeVMPattern
     baseShapeW = createOffset(mixedSizes, rank - 1);
     baseShapeH = createOffset(mixedSizes, rank - 2);
     if (sourceMemrefTy) {
-      // cast index to i64.
+      // Cast index to i64.
       baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
     } else if (baseAddr.getType() != i64Ty) {
-      // pointer type may be i32. Cast to i64 if needed.
+      // Pointer type may be i32. Cast to i64 if needed.
       baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
     }
     // Populate payload.
@@ -255,7 +255,7 @@ class UpdateNdOffsetToXeVMPattern
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
     auto tdesc = adaptor.getTensorDesc();
-    // utility for updating payload offset values from op fold result.
+    // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
           getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
@@ -425,7 +425,7 @@ class CreateDescToXeVMPattern
           op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
-    // offsets are provided as scalar i64 by type converter.
+    // Offsets are provided as scalar i64 by type converter.
     auto offsets = adaptor.getOffsets();
     // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
     // But type converter will convert them to integer types.
@@ -453,8 +453,8 @@ class UpdateOffsetToXeVMPattern
           op, "Expected element type bit width to be multiple of 8.");
     }
     auto loc = op.getLoc();
-    // scatter descriptor is provided as scalar i64 by type converter.
-    // offsets are provided as scalar i64 by type converter.
+    // Scatter descriptor is provided as scalar i64 by type converter.
+    // Offsets are provided as scalar i64 by type converter.
     Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
                                 adaptor.getOffsets(), eBw / 8);
     rewriter.replaceOp(op, newOffset);
@@ -583,7 +583,7 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
       rewriter.replaceOp(op, ifOp.getResult(0));
     } else {
-      // if mask is true, perform the store.
+      // If mask is true, perform the store.
       scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
       auto body = ifOp.getBody();
       rewriter.setInsertionPointToStart(body);
@@ -758,7 +758,7 @@ class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
         VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
     if (cvecty != cNty)
       c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
-    // below are uArch dependent values, should move away from hardcoding
+    // Below are uArch dependent values, should move away from hardcoding
     constexpr int32_t systolicDepth{8};
     constexpr int32_t executionSize{16};
     Value dpasRes = xevm::MMAOp::create(
@@ -818,7 +818,6 @@ matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
   default:
     return std::nullopt;
   }
-  llvm_unreachable("Invalid AtomicRMWKind");
 }
 
 class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {



More information about the Mlir-commits mailing list