[Mlir-commits] [mlir] [MLIR][XeGPU] Add lowering from transfer_read/transfer_write to load_gather/store_scatter (PR #152429)

Jianhui Li llvmlistbot at llvm.org
Wed Aug 13 15:02:02 PDT 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/152429

>From 06a43e93243705bc86c0b3fda14650646fcd1d45 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 31 Jul 2025 18:17:47 +0000
Subject: [PATCH 01/17] add initial lowering

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 169 ++++++++++++++++++
 1 file changed, 169 insertions(+)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 80107554144cf..ac3234d050a74 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -21,6 +21,9 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
 #include <algorithm>
 #include <optional>
 
@@ -155,6 +158,166 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+std::optional<std::string> getXeGPUChipStr(Operation *op) {
+  auto gpuModuleOp = op->getParentOfType<mlir::gpu::GPUModuleOp>();
+  if (gpuModuleOp) {
+    auto targetAttrs = gpuModuleOp.getTargets();
+    if (targetAttrs) {
+      for (auto &attr : *targetAttrs) {
+        auto xevmAttr = llvm::dyn_cast<mlir::xevm::XeVMTargetAttr>(attr);
+        if (xevmAttr)
+          return xevmAttr.getChip().str();
+      }
+    }
+  }
+  return std::nullopt;
+}
+
+// This function lowers vector.transfer_read to XeGPU load operation.
+  // Example:
+  //   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], 
+  //               %cst {in_bounds = [true, true, true, true]} :
+  //               memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+  // 
+  //   %6 = vector.step: vector<4xindex> 
+  //   %7 = vector.step: vector<2xindex> 
+  //   %8 = vector.step: vector<6xindex> 
+  //   %9 = vector.step: vector<32xindex> 
+  //   %10 = arith.mul %6, 384
+  //   %11 = arith.mul %7, 192
+  //   %12 = arith.mul %8, 32
+  //   %13 = arith.mul %9, 1
+  //   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+  //   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+  //   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+  //   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+  //   %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+  //   %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+  //   %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+  //   %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+  //   %22 = arith.add %18, %19
+  //   %23 = arith.add %20, %21
+  //   %local_offsets = arith.add %22, %23
+  //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+  //   %offsets =  orig_offset + local_offsets
+  //   %expand_shape1 = memref.view %expand_shape: memref<8x4x2x6x32xbf16> -> memref<?xbf16>
+  //   %vec = xegpu.load_gather %expand_shape1[%offsets]:memref<?xbf16>,
+  //                           vector<4x2x6x32xindex> -> vector<4x2x6x32xbf16>
+
+LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) {
+    Location loc = readOp.getLoc();
+  // Get the source memref and vector type
+  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+  if (!memrefType) {
+    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+  }
+  
+  VectorType vectorType = readOp.getVectorType();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  Type elementType = vectorType.getElementType();
+  
+  // Get memref strides for offset calculation
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefType.getStridesAndOffset(strides, offset))) {
+    return rewriter.notifyMatchFailure(readOp, "Failed to get memref strides");
+  }
+  
+  // Step 1: Create vector.step operations for each dimension
+  SmallVector<Value> stepVectors;
+  for (int64_t dim : vectorShape) {
+    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+    auto stepOp = rewriter.create<vector::StepOp>(loc, stepType);
+    stepVectors.push_back(stepOp);
+  }
+  
+  // Step 2: Multiply step vectors by corresponding strides
+  SmallVector<Value> strideMultiplied;
+  size_t memrefRank = memrefType.getRank();
+  size_t vectorRank = vectorShape.size();
+  
+  for (size_t i = 0; i < vectorRank; ++i) {
+    // Map vector dimension to memref dimension (innermost dimensions)
+    size_t memrefDim = memrefRank - vectorRank + i;
+    int64_t stride = strides[memrefDim];
+    
+    Value strideConstant = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+    
+    // Create element-wise multiplication
+    auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
+    auto mulOp = rewriter.create<arith::MulIOp>(loc, stepVectors[i], 
+                                               rewriter.create<vector::SplatOp>(loc, strideConstant, mulType));
+    strideMultiplied.push_back(mulOp);
+  }
+  
+  // Step 3: Shape cast each multiplied vector to add singleton dimensions
+  SmallVector<Value> shapeCasted;
+  for (size_t i = 0; i < vectorRank; ++i) {
+    SmallVector<int64_t> newShape(vectorRank, 1);
+    newShape[i] = vectorShape[i];
+    
+    auto newType = VectorType::get(newShape, rewriter.getIndexType());
+    auto castOp = rewriter.create<vector::ShapeCastOp>(loc, newType, strideMultiplied[i]);
+    shapeCasted.push_back(castOp);
+  }
+  
+  // Step 4: Broadcast each shape-casted vector to full vector shape
+  SmallVector<Value> broadcasted;
+  auto fullIndexVectorType = VectorType::get(vectorShape, rewriter.getIndexType());
+  
+  for (Value shapeCastVal : shapeCasted) {
+    auto broadcastOp = rewriter.create<vector::BroadcastOp>(loc, fullIndexVectorType, shapeCastVal);
+    broadcasted.push_back(broadcastOp);
+  }
+  
+  // Step 5: Add all broadcasted vectors together to compute local offsets
+  Value localOffsets = broadcasted[0];
+  for (size_t i = 1; i < broadcasted.size(); ++i) {
+    localOffsets = rewriter.create<arith::AddIOp>(loc, localOffsets, broadcasted[i]);
+  }
+  
+  // Step 6: Compute base offset from transfer read indices
+  Value baseOffset = nullptr;
+  auto indices = readOp.getIndices();
+  
+  if (!indices.empty()) {
+    // Calculate linearized base offset: sum(index[i] * stride[i])
+    baseOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    
+    for (size_t i = 0; i < indices.size(); ++i) {
+      Value strideVal = rewriter.create<arith::ConstantIndexOp>(loc, strides[i]);
+      Value offsetContrib = rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
+      baseOffset = rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
+    }
+    
+    // Broadcast base offset to match vector shape
+    Value splatBase = rewriter.create<vector::SplatOp>(loc, baseOffset, fullIndexVectorType);
+    localOffsets = rewriter.create<arith::AddIOp>(loc, splatBase, localOffsets);
+  }
+  
+  // Step 7: Create flattened memref view
+  auto flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+  auto viewOp = rewriter.create<memref::ViewOp>(loc, flatMemrefType, 
+                                               readOp.getBase(), 
+                                               rewriter.create<arith::ConstantIndexOp>(loc, 0),
+                                               ValueRange{});
+  
+  // Step 8: Create XeGPU gather load operation
+  auto gatherOp = rewriter.create<xegpu::LoadGatherOp>(loc, vectorType, 
+                                                       viewOp, localOffsets,
+                                                       /*mask=*/Value{},
+                                                       /*l1_hint=*/nullptr,
+                                                       /*l2_hint=*/nullptr,
+                                                       /*l3_hint=*/nullptr);
+  
+  // Replace the original transfer read with gather load
+  rewriter.replaceOp(readOp, gatherOp.getResult());
+  
+  return success();
+}
+
+
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
 
@@ -164,6 +327,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
+    
+    auto chip = getXeGPUChipStr(readOp);
+    if ( chip != "pvc" && chip != "bmg") {
+      // calling another function that lower TransferReadOp to regular Loadop
+      return lowerTransferReadToLoadOp(readOp, rewriter);
+    }
 
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))

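For reference, the step / multiply / shape_cast / broadcast / add chain built by
this patch just materializes the usual strided linearization of each vector
element's address. A minimal standalone C++ sketch of the same arithmetic for a
rank-2 vector read from a rank-2 strided memref (illustrative only, not part of
the patch; the helper name is made up):

#include <array>
#include <cstdint>
#include <vector>

// Scalar reference for the gather offsets: element (i, j) of the result vector
// is fetched from the flattened buffer at base + i * strides[0] + j * strides[1],
// where base linearizes the transfer_read indices with the same strides.
std::vector<int64_t> referenceGatherOffsets(std::array<int64_t, 2> vecShape,
                                            std::array<int64_t, 2> strides,
                                            std::array<int64_t, 2> readIndices) {
  int64_t base = readIndices[0] * strides[0] + readIndices[1] * strides[1];
  std::vector<int64_t> offsets;
  offsets.reserve(vecShape[0] * vecShape[1]);
  for (int64_t i = 0; i < vecShape[0]; ++i)
    for (int64_t j = 0; j < vecShape[1]; ++j)
      offsets.push_back(base + i * strides[0] + j * strides[1]);
  return offsets;
}

xegpu.load_gather then reads one element of the flattened memref per offset.
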
>From 933207716e622c05afb720eb2b24d311a1543dc8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 1 Aug 2025 02:09:19 +0000
Subject: [PATCH 02/17] add chipstr check and lowering

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |   2 +
 .../Conversion/VectorToXeGPU/CMakeLists.txt   |   1 +
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 291 ++++++++++++------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |   6 +-
 mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt   |   3 +
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  17 +
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 175 +++++++----
 .../transfer-write-to-xegpu.mlir              |  78 +++--
 8 files changed, 390 insertions(+), 183 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 488f358ff3802..67f74a4fa2e0e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -123,6 +123,8 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 void doSCFStructuralTypeConversionWithTensorType(Operation *op,
                                                  TypeConverter converter);
 
+std::optional<std::string> getXeGPUChipStr(Operation *op);
+
 } // namespace xegpu
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
index 567083da00239..e9ad67c52820d 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU
   MLIRTransforms
   MLIRVectorDialect
   MLIRXeGPUDialect
+  MLIRXeGPUUtils
   )
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index ac3234d050a74..c22a709c45d46 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -17,13 +17,11 @@
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/TypeSwitch.h"
 
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
-#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
-
 #include <algorithm>
 #include <optional>
 
@@ -158,21 +156,6 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
-std::optional<std::string> getXeGPUChipStr(Operation *op) {
-  auto gpuModuleOp = op->getParentOfType<mlir::gpu::GPUModuleOp>();
-  if (gpuModuleOp) {
-    auto targetAttrs = gpuModuleOp.getTargets();
-    if (targetAttrs) {
-      for (auto &attr : *targetAttrs) {
-        auto xevmAttr = llvm::dyn_cast<mlir::xevm::XeVMTargetAttr>(attr);
-        if (xevmAttr)
-          return xevmAttr.getChip().str();
-      }
-    }
-  }
-  return std::nullopt;
-}
-
 // This function lowers vector.transfer_read to XeGPU load operation.
   // Example:
   //   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], 
@@ -200,30 +183,21 @@ std::optional<std::string> getXeGPUChipStr(Operation *op) {
   //   %local_offsets = arith.add %22, %23
   //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
   //   %offsets =  orig_offset + local_offsets
-  //   %expand_shape1 = memref.view %expand_shape: memref<8x4x2x6x32xbf16> -> memref<?xbf16>
-  //   %vec = xegpu.load_gather %expand_shape1[%offsets]:memref<?xbf16>,
-  //                           vector<4x2x6x32xindex> -> vector<4x2x6x32xbf16>
 
-LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter) {
-    Location loc = readOp.getLoc();
-  // Get the source memref and vector type
-  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
-  if (!memrefType) {
-    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
-  }
-  
-  VectorType vectorType = readOp.getVectorType();
+//   %expand_shape1 = memref.collapse_shape %expand_shape:
+//   memref<8x4x2x6x32xbf16> -> memref<?xbf16>
+
+//   %vec = xegpu.load_gather %expand_shape1[%offsets]:memref<?xbf16>,
+//                           vector<4x2x6x32xindex> -> vector<4x2x6x32xbf16>
+
+// Compute localOffsets for load_gather and store_scatter
+static Value computeGatherOffsets(VectorTransferOpInterface readOp,
+                                  PatternRewriter &rewriter,
+                                  ArrayRef<int64_t> strides,
+                                  VectorType vectorType) {
+  Location loc = readOp.getLoc();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  Type elementType = vectorType.getElementType();
-  
-  // Get memref strides for offset calculation
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefType.getStridesAndOffset(strides, offset))) {
-    return rewriter.notifyMatchFailure(readOp, "Failed to get memref strides");
-  }
-  
+
   // Step 1: Create vector.step operations for each dimension
   SmallVector<Value> stepVectors;
   for (int64_t dim : vectorShape) {
@@ -231,92 +205,224 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
     auto stepOp = rewriter.create<vector::StepOp>(loc, stepType);
     stepVectors.push_back(stepOp);
   }
-  
+
   // Step 2: Multiply step vectors by corresponding strides
-  SmallVector<Value> strideMultiplied;
-  size_t memrefRank = memrefType.getRank();
+  size_t memrefRank = strides.size();
   size_t vectorRank = vectorShape.size();
-  
+  SmallVector<Value> strideMultiplied;
   for (size_t i = 0; i < vectorRank; ++i) {
-    // Map vector dimension to memref dimension (innermost dimensions)
     size_t memrefDim = memrefRank - vectorRank + i;
     int64_t stride = strides[memrefDim];
-    
     Value strideConstant = rewriter.create<arith::ConstantIndexOp>(loc, stride);
-    
-    // Create element-wise multiplication
     auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
-    auto mulOp = rewriter.create<arith::MulIOp>(loc, stepVectors[i], 
-                                               rewriter.create<vector::SplatOp>(loc, strideConstant, mulType));
+    auto mulOp = rewriter.create<arith::MulIOp>(
+        loc, stepVectors[i],
+        rewriter.create<vector::SplatOp>(loc, strideConstant, mulType));
     strideMultiplied.push_back(mulOp);
   }
-  
+
   // Step 3: Shape cast each multiplied vector to add singleton dimensions
   SmallVector<Value> shapeCasted;
   for (size_t i = 0; i < vectorRank; ++i) {
     SmallVector<int64_t> newShape(vectorRank, 1);
     newShape[i] = vectorShape[i];
-    
     auto newType = VectorType::get(newShape, rewriter.getIndexType());
-    auto castOp = rewriter.create<vector::ShapeCastOp>(loc, newType, strideMultiplied[i]);
+    auto castOp =
+        rewriter.create<vector::ShapeCastOp>(loc, newType, strideMultiplied[i]);
     shapeCasted.push_back(castOp);
   }
-  
+
   // Step 4: Broadcast each shape-casted vector to full vector shape
   SmallVector<Value> broadcasted;
-  auto fullIndexVectorType = VectorType::get(vectorShape, rewriter.getIndexType());
-  
+  auto fullIndexVectorType =
+      VectorType::get(vectorShape, rewriter.getIndexType());
   for (Value shapeCastVal : shapeCasted) {
-    auto broadcastOp = rewriter.create<vector::BroadcastOp>(loc, fullIndexVectorType, shapeCastVal);
+    auto broadcastOp = rewriter.create<vector::BroadcastOp>(
+        loc, fullIndexVectorType, shapeCastVal);
     broadcasted.push_back(broadcastOp);
   }
-  
+
   // Step 5: Add all broadcasted vectors together to compute local offsets
   Value localOffsets = broadcasted[0];
   for (size_t i = 1; i < broadcasted.size(); ++i) {
-    localOffsets = rewriter.create<arith::AddIOp>(loc, localOffsets, broadcasted[i]);
+    localOffsets =
+        rewriter.create<arith::AddIOp>(loc, localOffsets, broadcasted[i]);
   }
-  
+
   // Step 6: Compute base offset from transfer read indices
   Value baseOffset = nullptr;
   auto indices = readOp.getIndices();
-  
   if (!indices.empty()) {
-    // Calculate linearized base offset: sum(index[i] * stride[i])
     baseOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    
     for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal = rewriter.create<arith::ConstantIndexOp>(loc, strides[i]);
-      Value offsetContrib = rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
-      baseOffset = rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
+      Value strideVal =
+          rewriter.create<arith::ConstantIndexOp>(loc, strides[i]);
+      Value offsetContrib =
+          rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
+      baseOffset =
+          rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
     }
-    
     // Broadcast base offset to match vector shape
-    Value splatBase = rewriter.create<vector::SplatOp>(loc, baseOffset, fullIndexVectorType);
+    Value splatBase =
+        rewriter.create<vector::SplatOp>(loc, baseOffset, fullIndexVectorType);
     localOffsets = rewriter.create<arith::AddIOp>(loc, splatBase, localOffsets);
   }
-  
-  // Step 7: Create flattened memref view
-  auto flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  auto viewOp = rewriter.create<memref::ViewOp>(loc, flatMemrefType, 
-                                               readOp.getBase(), 
-                                               rewriter.create<arith::ConstantIndexOp>(loc, 0),
-                                               ValueRange{});
-  
-  // Step 8: Create XeGPU gather load operation
-  auto gatherOp = rewriter.create<xegpu::LoadGatherOp>(loc, vectorType, 
-                                                       viewOp, localOffsets,
-                                                       /*mask=*/Value{},
-                                                       /*l1_hint=*/nullptr,
-                                                       /*l2_hint=*/nullptr,
-                                                       /*l3_hint=*/nullptr);
-  
-  // Replace the original transfer read with gather load
+
+  return localOffsets;
+}
+
+// Collapse memref shape to 1D
+static Value collapseMemrefTo1D(VectorTransferOpInterface readOp,
+                                PatternRewriter &rewriter,
+                                MemRefType memrefType, Type elementType) {
+  Location loc = readOp.getLoc();
+  int64_t totalElements = 1;
+  bool hasDynamicDim = false;
+  for (int64_t dim : memrefType.getShape()) {
+    if (dim == ShapedType::kDynamic) {
+      hasDynamicDim = true;
+      break;
+    }
+    totalElements *= dim;
+  }
+
+  MemRefType flatMemrefType;
+  if (hasDynamicDim) {
+    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+  } else {
+    flatMemrefType = MemRefType::get({totalElements}, elementType);
+  }
+
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices allDims;
+  for (int i = 0; i < memrefType.getRank(); ++i) {
+    allDims.push_back(i);
+  }
+  reassociation.push_back(allDims);
+
+  auto collapseOp = rewriter.create<memref::CollapseShapeOp>(
+      loc, flatMemrefType, readOp.getBase(), reassociation);
+  return collapseOp;
+}
+
+// Create XeGPU gather load operation
+static LogicalResult createLoadGather(vector::TransferReadOp readOp,
+                                      PatternRewriter &rewriter,
+                                      Value flatMemref, Value localOffsets,
+                                      VectorType vectorType) {
+  Location loc = readOp.getLoc();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  Value mask = rewriter.create<vector::ConstantMaskOp>(
+      loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
+  auto gatherOp = rewriter.create<xegpu::LoadGatherOp>(
+      loc, vectorType, flatMemref, localOffsets, mask,
+      /*chunk_size=*/IntegerAttr{},
+      /*l1_hint=*/xegpu::CachePolicyAttr{},
+      /*l2_hint=*/xegpu::CachePolicyAttr{},
+      /*l3_hint=*/xegpu::CachePolicyAttr{});
   rewriter.replaceOp(readOp, gatherOp.getResult());
-  
   return success();
 }
 
+// Create XeGPU store scatter operation
+static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
+                                        PatternRewriter &rewriter,
+                                        Value flatMemref, Value localOffsets,
+                                        Value value, VectorType vectorType) {
+  Location loc = writeOp.getLoc();
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  Value mask = rewriter.create<vector::ConstantMaskOp>(
+      loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
+  rewriter.create<xegpu::StoreScatterOp>(loc, value, flatMemref, localOffsets,
+                                         mask,
+                                         /*chunk_size=*/IntegerAttr{},
+                                         /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                         /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                         /*l3_hint=*/xegpu::CachePolicyAttr{});
+  rewriter.eraseOp(writeOp);
+  return success();
+}
+
+LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
+                                        PatternRewriter &rewriter) {
+
+  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+  VectorType vectorType = readOp.getVectorType();
+  Type elementType = vectorType.getElementType();
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefType.getStridesAndOffset(strides, offset)))
+    return rewriter.notifyMatchFailure(readOp, "Failed to get memref strides");
+
+  Value localOffsets =
+      computeGatherOffsets(readOp, rewriter, strides, vectorType);
+  Value flatMemref =
+      collapseMemrefTo1D(readOp, rewriter, memrefType, elementType);
+  return createLoadGather(readOp, rewriter, flatMemref, localOffsets,
+                          vectorType);
+}
+
+LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
+                                          PatternRewriter &rewriter) {
+
+  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+  VectorType vectorType = writeOp.getVectorType();
+  Type elementType = vectorType.getElementType();
+
+  SmallVector<int64_t> strides;
+  int64_t offset;
+  if (failed(memrefType.getStridesAndOffset(strides, offset)))
+    return rewriter.notifyMatchFailure(writeOp, "Failed to get memref strides");
+
+  // Compute localOffsets for store_scatter
+  Value localOffsets =
+      computeGatherOffsets(writeOp, rewriter, strides, vectorType);
+
+  Value flatMemref =
+      collapseMemrefTo1D(writeOp, rewriter, memrefType, elementType);
+
+  return createStoreScatter(writeOp, rewriter, flatMemref, localOffsets,
+                            writeOp.getVector(), vectorType);
+}
+
+static LogicalResult
+extraCheckForScatteredLoadStore(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) {
+  // 1. It must be an in-bounds access, i.e. the in_bounds attribute must not
+  // contain any 'false' entry such as {in_bounds = [false, true]}.
+  if (readOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(
+        readOp, "Out-of-bounds access is not supported for this chip");
+  // 2. If the memref has a static shape, its trailing dimensions must exactly
+  // match the vector shape.
+  if (auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType())) {
+    if (memrefType.hasStaticShape()) {
+      ArrayRef<int64_t> memrefShape = memrefType.getShape();
+      ArrayRef<int64_t> vectorShape = readOp.getVectorType().getShape();
+      size_t memrefRank = memrefShape.size();
+      size_t vectorRank = vectorShape.size();
+      if (vectorRank > memrefRank)
+        return rewriter.notifyMatchFailure(
+            readOp, "Vector rank cannot exceed memref rank");
+      // Compare the last vectorRank dimensions of memref with vector shape
+      for (size_t i = 0; i < vectorRank; ++i) {
+        if (memrefShape[memrefRank - vectorRank + i] != vectorShape[i])
+          return rewriter.notifyMatchFailure(
+              readOp, "Memref lower dimensions must match vector shape");
+      }
+    }
+  }
+  return success();
+}
 
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -327,9 +433,12 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
-    
-    auto chip = getXeGPUChipStr(readOp);
+
+    auto chip = xegpu::getXeGPUChipStr(readOp);
     if ( chip != "pvc" && chip != "bmg") {
+      // perform additional checks -
+      if (failed(extraCheckForScatteredLoadStore(readOp, rewriter)))
+        return failure();
       // calling another function that lower TransferReadOp to regular Loadop
       return lowerTransferReadToLoadOp(readOp, rewriter);
     }
@@ -390,6 +499,12 @@ struct TransferWriteLowering
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
 
+    auto chip = xegpu::getXeGPUChipStr(writeOp);
+    if (chip != "pvc" && chip != "bmg") {
+      // Lower TransferWriteOp to a regular (scattered) store instead of store_nd.
+      return lowerTransferWriteToStoreOp(writeOp, rewriter);
+    }
+
     AffineMap map = writeOp.getPermutationMap();
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3fa229e..570689bc0969e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -123,7 +123,7 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
 
   // a valid shape for SIMT case
   if (valueTy.getRank() == 1) {
-    if (valueTy.getNumElements() != chunkSize)
+    if (valueTy.getNumElements() % chunkSize != 0)
       return emitError() << "value elements must match chunk size " << chunkSize
                          << " for SIMT code.";
     return success();
@@ -674,7 +674,7 @@ LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
 
   if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+    return emitOpError("Expects a scattered TensorDesc.");
 
   if (!tdescTy && getRankOf(getSource()) > 1)
     return emitOpError(
@@ -755,7 +755,7 @@ LogicalResult StoreScatterOp::verify() {
   auto valueTy = getValueType();
 
   if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+    return emitOpError("Expects a scattered TensorDesc.");
 
   if (!tdescTy && getRankOf(getDest()) > 1)
     return emitOpError(
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 98e84a4420722..23c26875476b6 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -7,5 +7,8 @@ add_mlir_dialect_library(MLIRXeGPUUtils
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRSCFTransforms
+  MLIRGPUDialect
+  MLIRLLVMDialect
+  MLIRXeVMDialect
   MLIRXeGPUDialect
   )
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2cf21fb802ba3..6f0b02897d271 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -404,3 +406,18 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
     (void)mlir::applyPartialConversion(op, target, std::move(patterns));
   }
 }
+
+std::optional<std::string> xegpu::getXeGPUChipStr(Operation *op) {
+  auto gpuModuleOp = op->getParentOfType<mlir::gpu::GPUModuleOp>();
+  if (gpuModuleOp) {
+    auto targetAttrs = gpuModuleOp.getTargets();
+    if (targetAttrs) {
+      for (auto &attr : *targetAttrs) {
+        auto xevmAttr = llvm::dyn_cast<mlir::xevm::XeVMTargetAttr>(attr);
+        if (xevmAttr)
+          return xevmAttr.getChip().str();
+      }
+    }
+  }
+  return std::nullopt;
+}
\ No newline at end of file
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index d1e5a62ad3e9b..920f88463732c 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -1,50 +1,97 @@
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD_ND
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD_GATHER
 
-func.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
+gpu.module @xevm_module {
+gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
     {in_bounds = [true]} : memref<8x16x32xf32>, vector<8xf32>
-  return %0 : vector<8xf32>
+  gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: @load_1D_vector(
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
-// CHECK:       return %[[VEC]]
+// LOAD_ND-LABEL: @load_1D_vector(
+// LOAD_ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_ND-SAME:  %[[OFFSET:.+]]: index
+// LOAD_ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD_ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD_ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD_ND-SAME:    boundary_check = false
+// LOAD_ND:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD_ND:       return %[[VEC]]
+
+// LOAD_GATHER-LABEL: @load_1D_vector(
+// LOAD_GATHER-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_GATHER-SAME:  %[[OFFSET:.+]]: index
+// LOAD_GATHER:       %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD_GATHER:       %[[C32:.+]] = arith.constant 32 : index
+// LOAD_GATHER:       %[[C512:.+]] = arith.constant 512 : index
+// LOAD_GATHER:       %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD_GATHER:       %[[MUL1:.+]] = arith.muli %[[OFFSET]], %[[C512]] : index
+// LOAD_GATHER:       %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
+// LOAD_GATHER:       %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
+// LOAD_GATHER:       %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
+// LOAD_GATHER:       %[[SPLAT:.+]] = vector.splat %[[ADD2]] : vector<8xindex>
+// LOAD_GATHER:       %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD_GATHER:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD_GATHER:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD_GATHER:       return %[[VEC]]
 
-// -----
+}
 
-func.func @load_2D_vector(%source: memref<8x16x32xf32>,
+// -----
+gpu.module @xevm_module {
+gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
     %offset: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
     {in_bounds = [true, true]} : memref<8x16x32xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @load_2D_vector(
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME:    boundary_check = false
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       return %[[VEC]]
+// LOAD_ND-LABEL: @load_2D_vector(
+// LOAD_ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_ND-SAME:  %[[OFFSET:.+]]: index
+// LOAD_ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD_ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD_ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD_ND-SAME:    boundary_check = false
+// LOAD_ND:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD_ND:       return %[[VEC]]
+
+// LOAD_GATHER-LABEL: @load_2D_vector(
+// LOAD_GATHER-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_GATHER-SAME:  %[[OFFSET:.+]]: index
+// LOAD_GATHER:       %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER:       %[[C32:.+]] = arith.constant 32 : index
+// LOAD_GATHER:       %[[C512:.+]] = arith.constant 512 : index
+// LOAD_GATHER:       %[[CST_0:.+]] = arith.constant dense<32> : vector<8xindex>
+// LOAD_GATHER:       %[[STEP0:.+]] = vector.step : vector<8xindex>
+// LOAD_GATHER:       %[[STEP1:.+]] = vector.step : vector<16xindex>
+// LOAD_GATHER:       %[[MUL:.+]] = arith.muli %[[STEP0]], %[[CST_0]] : vector<8xindex>
+// LOAD_GATHER:       %[[SHAPE0:.+]] = vector.shape_cast %[[MUL]] : vector<8xindex> to vector<8x1xindex>
+// LOAD_GATHER:       %[[SHAPE1:.+]] = vector.shape_cast %[[STEP1]] : vector<16xindex> to vector<1x16xindex>
+// LOAD_GATHER:       %[[BROADCAST0:.+]] = vector.broadcast %[[SHAPE0]] : vector<8x1xindex> to vector<8x16xindex>
+// LOAD_GATHER:       %[[BROADCAST1:.+]] = vector.broadcast %[[SHAPE1]] : vector<1x16xindex> to vector<8x16xindex>
+// LOAD_GATHER:       %[[ADD_VEC:.+]] = arith.addi %[[BROADCAST0]], %[[BROADCAST1]] : vector<8x16xindex>
+// LOAD_GATHER:       %[[MUL1:.+]] = arith.muli %[[OFFSET]], %[[C512]] : index
+// LOAD_GATHER:       %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
+// LOAD_GATHER:       %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
+// LOAD_GATHER:       %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
+// LOAD_GATHER:       %[[SPLAT:.+]] = vector.splat %[[ADD2]] : vector<8x16xindex>
+// LOAD_GATHER:       %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[ADD_VEC]] : vector<8x16xindex>
+// LOAD_GATHER:       %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD_GATHER:       %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD_GATHER:       return %[[VEC]]
+}
 
 // -----
-
-func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
     %offset: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset], %c0
     {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_zero_pad_out_of_bounds(
@@ -54,16 +101,17 @@ func.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
 // CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
+}
 
 // -----
-
-func.func @load_transposed(%source: memref<32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @load_transposed(%source: memref<32x64xf32>,
     %offset: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset], %c0
     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_transposed(
@@ -74,15 +122,16 @@ func.func @load_transposed(%source: memref<32x64xf32>,
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
 // CHECK-SAME:    -> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
+}
 
 // -----
-
-func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
     %offset: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
     {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @load_dynamic_source(
@@ -100,113 +149,121 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK:       return %[[VEC]]
-
+}
 // -----
-
-func.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
     %offset: index, %arg2: index, %pad: f32) -> (vector<8x16xf32>, vector<8x16xf32>) {
   %c1 = arith.constant 1.0 : f32
   %0 = vector.transfer_read %source[%offset, %arg2], %c1
     {in_bounds = [true, false]} : memref<32x64xf32>, vector<8x16xf32>
   %1 = vector.transfer_read %source[%arg2, %offset], %pad
     {in_bounds = [false, true]} : memref<32x64xf32>, vector<8x16xf32>
-  return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
+  gpu.return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
 }
 
 // CHECK-LABEL:   @no_load_out_of_bounds_non_zero_pad(
 // CHECK-COUNT-2: vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
     %offset: index) -> vector<8xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
     {in_bounds = [false]} : memref<8x16x32xf32>, vector<8xf32>
-  return %0 : vector<8xf32>
+  gpu.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: @no_load_out_of_bounds_1D_vector(
 // CHECK:       vector.transfer_read
-
+}
 // -----
-
-func.func @no_load_masked(%source : memref<4xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_masked(%source : memref<4xf32>,
     %offset : index) -> vector<4xf32> {
   %c0 = arith.constant 0.0 : f32
   %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
   %0 = vector.transfer_read %source[%offset], %c0, %mask
     {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
-  return %0 : vector<4xf32>
+  gpu.return %0 : vector<4xf32>
 }
 
 // CHECK-LABEL: @no_load_masked(
 // CHECK:       vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_tensor(%source: tensor<32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
     %offset: index, %arg2: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %arg2], %c0
     {in_bounds = [true, true]} : tensor<32x64xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @no_load_tensor(
 // CHECK:       vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
     %offset: index, %arg2: index) -> vector<8x16x32xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
     {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
-  return %0 : vector<8x16x32xf32>
+  gpu.return %0 : vector<8x16x32xf32>
 }
 
 // CHECK-LABEL: @no_load_high_dim_vector(
 // CHECK:       vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_non_unit_inner_stride(
+gpu.module @xevm_module {
+gpu.func @no_load_non_unit_inner_stride(
     %source: memref<32xf32, strided<[?], offset: ?>>,
     %offset: index) -> vector<8xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset], %c0 {in_bounds = [true]}
     : memref<32xf32, strided<[?], offset: ?>>, vector<8xf32>
-  return %0 : vector<8xf32>
+  gpu.return %0 : vector<8xf32>
 }
 
 // CHECK-LABEL: @no_load_non_unit_inner_stride(
 // CHECK:       vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
+gpu.module @xevm_module {
+gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
     %offset: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
   %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
     {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
     in_bounds = [true, true]} : memref<16x32x64xf32>, vector<8x16xf32>
-  return %0 : vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
 }
 
 // CHECK-LABEL: @no_load_unsupported_map(
 // CHECK:       vector.transfer_read
+}
 
 // -----
-
-func.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
+gpu.module @xevm_module {
+gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
     %offset: index) -> vector<8x16xf16> {
   %c0 = arith.constant 0.0 : f16
   %0 = vector.transfer_read %source[%offset, %offset], %c0
     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
-  return %0 : vector<8x16xf16>
+  gpu.return %0 : vector<8x16xf16>
 }
 
 // CHECK-LABEL: @no_load_transpose_unsupported_data_type(
 // CHECK:       vector.transfer_read
+}
+
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index d5f1221aebed5..e244995dd5817 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,11 +1,12 @@
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s
 
-func.func @store_1D_vector(%vec: vector<8xf32>,
+gpu.module @xevm_module {
+gpu.func @store_1D_vector(%vec: vector<8xf32>,
     %source: memref<8x16x32xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {in_bounds = [true]}
     : vector<8xf32>, memref<8x16x32xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @store_1D_vector(
@@ -17,15 +18,16 @@ func.func @store_1D_vector(%vec: vector<8xf32>,
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+}
 
 // -----
-
-func.func @store_2D_vector(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
     %source: memref<8x16x32xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {in_bounds = [true, true]}
     : vector<8x16xf32>, memref<8x16x32xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @store_2D_vector(
@@ -37,15 +39,16 @@ func.func @store_2D_vector(%vec: vector<8x16xf32>,
 // CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
 // CHECK-SAME:    boundary_check = false
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+}
 
 // -----
-
-func.func @store_dynamic_source(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
     %source: memref<?x?x?xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {in_bounds = [true, true]}
     : vector<8x16xf32>, memref<?x?x?xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @store_dynamic_source(
@@ -63,15 +66,16 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+}
 
 // -----
-
-func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
     %source: memref<7x64xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset]
     {in_bounds = [false, true]}
     : vector<8x16xf32>, memref<7x64xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL:   @store_out_of_bounds(
@@ -82,97 +86,105 @@ func.func @store_out_of_bounds(%vec: vector<8x16xf32>,
 // CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
 // CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
 // CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+}
 
 // -----
-
-func.func @no_store_transposed(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
     %source: memref<32x64xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset]
     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]}
     : vector<8x16xf32>, memref<32x64xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_transposed(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_masked(%vec: vector<4xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_masked(%vec: vector<4xf32>,
     %source: memref<4xf32>, %offset: index) {
   %mask = arith.constant dense<[0, 1, 0, 1]> : vector<4xi1>
   vector.transfer_write %vec, %source[%offset], %mask
     {in_bounds = [true]}
     : vector<4xf32>, memref<4xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_masked(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_tensor(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_tensor(%vec: vector<8x16xf32>,
     %source: tensor<32x64xf32>, %offset: index) -> tensor<32x64xf32> {
   %0 = vector.transfer_write %vec, %source[%offset, %offset]
     {in_bounds = [true, true]}
     : vector<8x16xf32>, tensor<32x64xf32>
-  return %0 : tensor<32x64xf32>
+  gpu.return %0 : tensor<32x64xf32>
 }
 
 // CHECK-LABEL: @no_store_tensor(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
     %source: memref<16x32x64xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {in_bounds = [true, true, true]}
     : vector<8x16x32xf32>, memref<16x32x64xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_high_dim_vector(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
     %source: memref<32xf32, strided<[?], offset: ?>>, %offset: index) {
   vector.transfer_write %vec, %source[%offset]
     {in_bounds = [true]}
     : vector<8xf32>, memref<32xf32, strided<[?], offset: ?>>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_non_unit_inner_stride(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
     %source: memref<16x32x64xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>,
     in_bounds = [true, true]}
     : vector<8x16xf32>, memref<16x32x64xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_unsupported_map(
 // CHECK:       vector.transfer_write
+}
 
 // -----
-
-func.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
+gpu.module @xevm_module {
+gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
     %source: memref<8x16x32xf32>, %offset: index) {
   vector.transfer_write %vec, %source[%offset, %offset, %offset]
     {in_bounds = [false]}
     : vector<8xf32>, memref<8x16x32xf32>
-  return
+  gpu.return
 }
 
 // CHECK-LABEL: @no_store_out_of_bounds_1D_vector(
 // CHECK:       vector.transfer_write
+}
\ No newline at end of file

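As a scalar model of the pair of ops this patch now emits (illustrative only,
assuming the all-true mask produced by vector.constant_mask above; not the XeGPU
implementation): the memref is collapsed to 1-D and every vector lane addresses
it through its own offset, guarded by the mask.

#include <cstddef>
#include <cstdint>
#include <vector>

// Gather: lane i reads flat[offsets[i]] when mask[i] is set.
template <typename T>
void referenceGather(const std::vector<T> &flat,
                     const std::vector<int64_t> &offsets,
                     const std::vector<bool> &mask, std::vector<T> &out) {
  for (size_t i = 0; i < offsets.size(); ++i)
    if (mask[i])
      out[i] = flat[offsets[i]];
}

// Scatter: lane i writes vals[i] to flat[offsets[i]] when mask[i] is set.
template <typename T>
void referenceScatter(std::vector<T> &flat,
                      const std::vector<int64_t> &offsets,
                      const std::vector<bool> &mask,
                      const std::vector<T> &vals) {
  for (size_t i = 0; i < offsets.size(); ++i)
    if (mask[i])
      flat[offsets[i]] = vals[i];
}
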
>From 4f17c69394523467d4255604acc217dbd275feba Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 1 Aug 2025 05:54:04 +0000
Subject: [PATCH 03/17] enable permutation mapping

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 38 ++++++++++++++++++-
 1 file changed, 37 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c22a709c45d46..2493026dfc548 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -343,6 +343,37 @@ static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
   return success();
 }
 
+static void adjustStridesForPermutation(vector::TransferReadOp readOp,
+                                        PatternRewriter &rewriter,
+                                        MemRefType memrefType,
+                                        SmallVectorImpl<int64_t> &strides) {
+  AffineMap permMap = readOp.getPermutationMap();
+  if (!permMap.isMinorIdentity()) {
+    SmallVector<int64_t> adjustedStrides;
+    unsigned vecRank = readOp.getVectorType().getRank();
+    unsigned memrefRank = memrefType.getRank();
+    // Only adjust the last vecRank strides according to the permutation
+    ArrayRef<int64_t> relevantStrides = ArrayRef<int64_t>(strides).take_back(vecRank);
+    for (AffineExpr expr : permMap.getResults().take_back(vecRank)) {
+      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+      if (!dimExpr) {
+        rewriter.notifyMatchFailure(readOp, "Unsupported permutation expr");
+        return;
+      }
+      unsigned pos = dimExpr.getPosition();
+      // Map permutation to the relevant strides (innermost dims)
+      if (pos < memrefRank - vecRank) {
+        rewriter.notifyMatchFailure(readOp, "Permutation out of bounds");
+        return;
+      }
+      adjustedStrides.push_back(relevantStrides[pos - (memrefRank - vecRank)]);
+    }
+    // Replace the last vecRank strides with the adjusted ones
+    for (unsigned i = 0; i < vecRank; ++i)
+      strides[memrefRank - vecRank + i] = adjustedStrides[i];
+  }
+}
+
 LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
                                         PatternRewriter &rewriter) {
 
@@ -358,10 +389,15 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
   if (failed(memrefType.getStridesAndOffset(strides, offset)))
     return rewriter.notifyMatchFailure(readOp, "Failed to get memref strides");
 
+  // Adjust strides according to the permutation map (e.g., for transpose)
+  adjustStridesForPermutation(readOp, rewriter, memrefType, strides);
+
   Value localOffsets =
       computeGatherOffsets(readOp, rewriter, strides, vectorType);
+
   Value flatMemref =
       collapseMemrefTo1D(readOp, rewriter, memrefType, elementType);
+  
   return createLoadGather(readOp, rewriter, flatMemref, localOffsets,
                           vectorType);
 }
@@ -415,7 +451,7 @@ extraCheckForScatteredLoadStore(vector::TransferReadOp readOp,
             readOp, "Vector rank cannot exceed memref rank");
       // Compare the last vectorRank dimensions of memref with vector shape
       for (size_t i = 0; i < vectorRank; ++i) {
-        if (memrefShape[memrefRank - vectorRank + i] != vectorShape[i])
+        if (memrefShape[memrefRank - vectorRank + i] < vectorShape[i])
           return rewriter.notifyMatchFailure(
               readOp, "Memref lower dimensions must match vector shape");
       }

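The stride adjustment added here boils down to: vector dimension outIdx uses the
stride of whichever trailing memref dimension the permutation map sends it to
(for a 2-D transpose the permutation is {1, 0}). A small C++ sketch of that
remapping (illustrative only; assumes the permutation has already been reduced
to positions relative to the trailing vecRank dimensions):

#include <cstdint>
#include <vector>

// Returns a copy of `strides` whose trailing vecRank entries are permuted so
// that entry outIdx holds the stride of source dimension permutation[outIdx].
std::vector<int64_t>
permuteTrailingStrides(std::vector<int64_t> strides,
                       const std::vector<unsigned> &permutation,
                       unsigned memrefRank, unsigned vecRank) {
  std::vector<int64_t> adjusted(vecRank);
  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx)
    adjusted[outIdx] = strides[memrefRank - vecRank + permutation[outIdx]];
  for (unsigned i = 0; i < vecRank; ++i)
    strides[memrefRank - vecRank + i] = adjusted[i];
  return strides;
}
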
>From ebf51349b257831d9c0395a2f33f4abf795310c2 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 2 Aug 2025 04:47:24 +0000
Subject: [PATCH 04/17] adding dynamic shape support

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 217 +++++++++++-------
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 150 ++++++++++--
 2 files changed, 267 insertions(+), 100 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2493026dfc548..20c11198d67c8 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -69,11 +69,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
   if (!srcTy)
     return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
 
-  // Perform common data transfer checks.
-  VectorType vecTy = xferOp.getVectorType();
-  if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
-    return failure();
-
   // Validate further transfer op semantics.
   SmallVector<int64_t> strides;
   int64_t offset;
@@ -81,6 +76,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
     return rewriter.notifyMatchFailure(
         xferOp, "Buffer must be contiguous in the innermost dimension");
 
+  VectorType vecTy = xferOp.getVectorType();
   unsigned vecRank = vecTy.getRank();
   if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
     return rewriter.notifyMatchFailure(
@@ -156,6 +152,93 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+static void adjustStridesForPermutation(Operation *op,
+                                        PatternRewriter &rewriter,
+                                        MemRefType memrefType,
+                                        AffineMap permMap, VectorType vecType,
+                                        SmallVectorImpl<Value> &strides) {
+  unsigned vecRank;
+  unsigned memrefRank = memrefType.getRank();
+
+  if (!permMap.isMinorIdentity()) {
+    vecRank = vecType.getRank();
+    // Only adjust the last vecRank strides according to the permutation
+    ArrayRef<Value> relevantStrides =
+        ArrayRef<Value>(strides).take_back(vecRank);
+    SmallVector<Value> adjustedStrides(vecRank);
+    // For each output dimension in the permutation map, find which input dim it
+    // refers to, and assign the corresponding stride.
+    for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+      AffineExpr expr = permMap.getResult(outIdx);
+      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+      if (!dimExpr) {
+        rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
+        return;
+      }
+      unsigned pos = dimExpr.getPosition();
+      // Map permutation to the relevant strides (innermost dims)
+      if (pos < memrefRank - vecRank) {
+        rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+        return;
+      }
+      // The stride for output dimension outIdx is the stride of input dimension
+      // pos
+      adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
+    }
+    // Replace the last vecRank strides with the adjusted ones
+    for (unsigned i = 0; i < vecRank; ++i)
+      strides[memrefRank - vecRank + i] = adjustedStrides[i];
+  }
+}
+
+SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                  PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  VectorType vectorType = xferOp.getVectorType();
+  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
+      rewriter.notifyMatchFailure(xferOp, "Failed to get memref strides");
+      return {};
+    }
+    // Wrap static strides as MLIR values
+    for (int64_t s : intStrides)
+      strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
+  } else {
+    // For memrefs with dynamic shape, use memref.extract_strided_metadata to
+    // get the stride values
+    unsigned rank = memrefType.getRank();
+    Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, size0, size1, ..., sizeN-1,
+    // stride0, stride1, ..., strideN-1]
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // rank-0 base memref
+    resultTypes.push_back(indexType);      // offset
+    for (unsigned i = 0; i < rank; ++i) {
+      resultTypes.push_back(indexType); // sizes
+    }
+    for (unsigned i = 0; i < rank; ++i) {
+      resultTypes.push_back(indexType); // strides
+    }
+
+    auto meta = rewriter.create<memref::ExtractStridedMetadataOp>(
+        loc, resultTypes, baseMemref);
+    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+  }
+  // Adjust strides according to the permutation map (e.g., for transpose)
+  adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap, vectorType,
+                              strides);
+  return strides;
+}
+
 // This function lowers vector.transfer_read to XeGPU load operation.
   // Example:
   //   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], 
@@ -191,11 +274,13 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
 //                           vector<4x2x6x32xindex> -> vector<4x2x6x32xbf16>
 
 // Compute localOffsets for load_gather and store_scatter
-static Value computeGatherOffsets(vector::TransferReadOp readOp,
+static Value computeGatherOffsets(VectorTransferOpInterface xferOp,
                                   PatternRewriter &rewriter,
-                                  ArrayRef<int64_t> strides,
-                                  VectorType vectorType) {
-  Location loc = readOp.getLoc();
+                                  ArrayRef<Value> strides) {
+  Location loc = xferOp.getLoc();
+  VectorType vectorType = xferOp.getVectorType();
+  SmallVector<Value> indices(xferOp.getIndices().begin(),
+                             xferOp.getIndices().end());
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
 
   // Step 1: Create vector.step operations for each dimension
@@ -212,12 +297,11 @@ static Value computeGatherOffsets(vector::TransferReadOp readOp,
   SmallVector<Value> strideMultiplied;
   for (size_t i = 0; i < vectorRank; ++i) {
     size_t memrefDim = memrefRank - vectorRank + i;
-    int64_t stride = strides[memrefDim];
-    Value strideConstant = rewriter.create<arith::ConstantIndexOp>(loc, stride);
+    Value strideValue = strides[memrefDim];
     auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
     auto mulOp = rewriter.create<arith::MulIOp>(
         loc, stepVectors[i],
-        rewriter.create<vector::SplatOp>(loc, strideConstant, mulType));
+        rewriter.create<vector::BroadcastOp>(loc, mulType, strideValue));
     strideMultiplied.push_back(mulOp);
   }
 
@@ -251,31 +335,33 @@ static Value computeGatherOffsets(vector::TransferReadOp readOp,
 
   // Step 6: Compute base offset from transfer read indices
   Value baseOffset = nullptr;
-  auto indices = readOp.getIndices();
   if (!indices.empty()) {
     baseOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
     for (size_t i = 0; i < indices.size(); ++i) {
-      Value strideVal =
-          rewriter.create<arith::ConstantIndexOp>(loc, strides[i]);
+      Value strideVal = strides[i];
       Value offsetContrib =
           rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
       baseOffset =
           rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
     }
     // Broadcast base offset to match vector shape
-    Value splatBase =
-        rewriter.create<vector::SplatOp>(loc, baseOffset, fullIndexVectorType);
-    localOffsets = rewriter.create<arith::AddIOp>(loc, splatBase, localOffsets);
+    Value bcastBase = rewriter.create<vector::BroadcastOp>(
+        loc, fullIndexVectorType, baseOffset);
+    localOffsets = rewriter.create<arith::AddIOp>(loc, bcastBase, localOffsets);
   }
-
   return localOffsets;
 }
 
 // Collapse memref shape to 1D
-static Value collapseMemrefTo1D(vector::TransferReadOp readOp,
-                                PatternRewriter &rewriter,
-                                MemRefType memrefType, Type elementType) {
-  Location loc = readOp.getLoc();
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  Location loc = xferOp.getLoc();
+
+  Value baseMemref = xferOp.getBase();
+  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+  Type elementType = memrefType.getElementType();
+
+  // Compute the total number of elements in the memref
   int64_t totalElements = 1;
   bool hasDynamicDim = false;
   for (int64_t dim : memrefType.getShape()) {
@@ -301,7 +387,7 @@ static Value collapseMemrefTo1D(vector::TransferReadOp readOp,
   reassociation.push_back(allDims);
 
   auto collapseOp = rewriter.create<memref::CollapseShapeOp>(
-      loc, flatMemrefType, readOp.getBase(), reassociation);
+      loc, flatMemrefType, baseMemref, reassociation);
   return collapseOp;
 }
 
@@ -343,37 +429,6 @@ static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
   return success();
 }
 
-static void adjustStridesForPermutation(vector::TransferReadOp readOp,
-                                        PatternRewriter &rewriter,
-                                        MemRefType memrefType,
-                                        SmallVectorImpl<int64_t> &strides) {
-  AffineMap permMap = readOp.getPermutationMap();
-  if (!permMap.isMinorIdentity()) {
-    SmallVector<int64_t> adjustedStrides;
-    unsigned vecRank = readOp.getVectorType().getRank();
-    unsigned memrefRank = memrefType.getRank();
-    // Only adjust the last vecRank strides according to the permutation
-    ArrayRef<int64_t> relevantStrides = ArrayRef<int64_t>(strides).take_back(vecRank);
-    for (AffineExpr expr : permMap.getResults().take_back(vecRank)) {
-      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
-      if (!dimExpr) {
-        rewriter.notifyMatchFailure(readOp, "Unsupported permutation expr");
-        return;
-      }
-      unsigned pos = dimExpr.getPosition();
-      // Map permutation to the relevant strides (innermost dims)
-      if (pos < memrefRank - vecRank) {
-        rewriter.notifyMatchFailure(readOp, "Permutation out of bounds");
-        return;
-      }
-      adjustedStrides.push_back(relevantStrides[pos - (memrefRank - vecRank)]);
-    }
-    // Replace the last vecRank strides with the adjusted ones
-    for (unsigned i = 0; i < vecRank; ++i)
-      strides[memrefRank - vecRank + i] = adjustedStrides[i];
-  }
-}
-
 LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
                                         PatternRewriter &rewriter) {
 
@@ -384,20 +439,12 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
   VectorType vectorType = readOp.getVectorType();
   Type elementType = vectorType.getElementType();
 
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefType.getStridesAndOffset(strides, offset)))
-    return rewriter.notifyMatchFailure(readOp, "Failed to get memref strides");
+  SmallVector<Value> strides = computeStrides(readOp, rewriter);
 
-  // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(readOp, rewriter, memrefType, strides);
+  Value localOffsets = computeGatherOffsets(readOp, rewriter, strides);
 
-  Value localOffsets =
-      computeGatherOffsets(readOp, rewriter, strides, vectorType);
+  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
 
-  Value flatMemref =
-      collapseMemrefTo1D(readOp, rewriter, memrefType, elementType);
-  
   return createLoadGather(readOp, rewriter, flatMemref, localOffsets,
                           vectorType);
 }
@@ -409,25 +456,21 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
+  Value baseMemref = writeOp.getBase();
+  AffineMap permMap = writeOp.getPermutationMap();
   VectorType vectorType = writeOp.getVectorType();
   Type elementType = vectorType.getElementType();
+  SmallVector<Value> indices(writeOp.getIndices().begin(),
+                             writeOp.getIndices().end());
 
-  SmallVector<int64_t> strides;
-  int64_t offset;
-  if (failed(memrefType.getStridesAndOffset(strides, offset)))
-    return rewriter.notifyMatchFailure(writeOp, "Failed to get memref strides");
+  SmallVector<Value> strides = computeStrides(writeOp, rewriter);
 
-  // Compute localOffsets for store_scatter
-  Value localOffsets =
-      computeGatherOffsets(cast<vector::TransferReadOp>(writeOp.getOperation()),
-                           rewriter, strides, vectorType);
+  Value localOffsets = computeGatherOffsets(writeOp, rewriter, strides);
 
-  Value flatMemref =
-      collapseMemrefTo1D(cast<vector::TransferReadOp>(writeOp.getOperation()),
-                         rewriter, memrefType, elementType);
+  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
 
-  return createStoreScatter(writeOp, rewriter, flatMemref, localOffsets,
-                            writeOp.getVector(), vectorType);
+  return createStoreScatter(writeOp, rewriter, writeOp.getVector(), flatMemref,
+                            localOffsets, vectorType);
 }
 
 static LogicalResult
@@ -436,8 +479,9 @@ extraCheckForScatteredLoadStore(vector::TransferReadOp readOp,
   // 1. it must be inbound access by checking in_bounds attributes, like
   // {in_bounds = [false, true]}
   if (readOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(
-        readOp, "Out-of-bounds access is not supported for this chip");
+    return rewriter.notifyMatchFailure(readOp,
+                                       "Out-of-bounds access is not supported "
+                                       "for scatter load/store lowering");
   // 2. if the memref has static shape, its lower rank must exactly match with
   // vector shape.
   if (auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType())) {
@@ -479,6 +523,11 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       return lowerTransferReadToLoadOp(readOp, rewriter);
     }
 
+    // Perform common data transfer checks.
+    VectorType vecTy = readOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+      return failure();
+
     bool isOutOfBounds = readOp.hasOutOfBoundsDim();
     if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
       return rewriter.notifyMatchFailure(
@@ -487,7 +536,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     AffineMap readMap = readOp.getPermutationMap();
     bool isTransposeLoad = !readMap.isMinorIdentity();
 
-    VectorType vecTy = readOp.getVectorType();
     Type elementType = vecTy.getElementType();
     unsigned minTransposeBitWidth = 32;
     if (isTransposeLoad &&
@@ -540,12 +588,15 @@ struct TransferWriteLowering
       // calling another function that lower TransferWriteOp to regular StoreOp
       return lowerTransferWriteToStoreOp(writeOp, rewriter);
     }
+    // Perform common data transfer checks.
+    VectorType vecTy = writeOp.getVectorType();
+    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+      return failure();
 
     AffineMap map = writeOp.getPermutationMap();
     if (!map.isMinorIdentity())
       return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
 
-    VectorType vecTy = writeOp.getVectorType();
     auto descType = xegpu::TensorDescType::get(
         vecTy.getShape(), vecTy.getElementType(),
         /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
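Before the test changes below, here is a scalar model of the offsets that computeStrides and computeGatherOffsets materialize as vector IR. This is a hand-written sketch (gatherOffset is a name invented here, not part of the patch), and it assumes a minor-identity permutation map:

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <vector>

  // For a read of a rank-K vector starting at `indices` into a rank-R strided
  // memref, the element gathered for lane coordinate (v0, ..., vK-1) sits at
  //   sum_i indices[i] * strides[i]       (scalar base, broadcast in the IR)
  // + sum_j v_j * strides[R - K + j]      (per-lane part built from vector.step)
  int64_t gatherOffset(const std::vector<int64_t> &strides,
                       const std::vector<int64_t> &indices,
                       const std::vector<int64_t> &lane) {
    size_t memrefRank = strides.size();
    size_t vecRank = lane.size();
    int64_t base = 0;
    for (size_t i = 0; i < indices.size(); ++i)
      base += indices[i] * strides[i];
    int64_t local = 0;
    for (size_t j = 0; j < vecRank; ++j)
      local += lane[j] * strides[memrefRank - vecRank + j];
    return base + local;
  }

  int main() {
    // memref<8x16x32xf32> (strides 512, 32, 1), vector<8x16xf32> read at
    // [2, 2, 2]; lane (3, 5) maps to element 2*512 + 2*32 + 2 + 3*32 + 5.
    assert(gatherOffset({512, 32, 1}, {2, 2, 2}, {3, 5}) == 1191);
    return 0;
  }
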
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 920f88463732c..b880a70187353 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -127,29 +127,145 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
 // -----
 gpu.module @xevm_module {
 gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
-    %offset: index) -> vector<8x16xf32> {
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %source[%offset, %offset, %offset], %c0
+  %0 = vector.transfer_read %source[%i, %j, %k], %c0
     {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
   gpu.return %0 : vector<8x16xf32>
 }
+// CHECK-LABEL: LOAD_ND: @load_dynamic_source(
+// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<?x?x?xf32>,
+// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
+// CHECK:       LOAD_ND: %[[C2:.+]] = arith.constant 2 : index
+// CHECK:       LOAD_ND: %[[C1:.+]] = arith.constant 1 : index
+// CHECK:       LOAD_ND: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   LOAD_ND: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// CHECK-DAG:   LOAD_ND: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// CHECK-DAG:   LOAD_ND: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// CHECK:       LOAD_ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
+// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       LOAD_ND: return %[[VEC]]
+
+
+// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source(%[[ARG0:.+]]: memref<?x?x?xf32>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index)
+// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// CHECK:       LOAD_GATHER: memref.extract_strided_metadata %[[ARG0]]
+// CHECK:       LOAD_GATHER: %[[STEP8:.+]] = vector.step : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[STEP16:.+]] = vector.step : vector<16xindex>
+// CHECK:       LOAD_GATHER: %[[BSTR1:.+]] = vector.broadcast %[[STR1:.+]] : index to vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[MUL8:.+]] = arith.muli %[[STEP8]], %[[BSTR1]] : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE8:.+]] = vector.shape_cast %[[MUL8]] : vector<8xindex> to vector<8x1xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE16:.+]] = vector.shape_cast %[[STEP16]] : vector<16xindex> to vector<1x16xindex>
+// CHECK:       LOAD_GATHER: %[[BROAD8:.+]] = vector.broadcast %[[SHAPE8]] : vector<8x1xindex> to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[BROAD16:.+]] = vector.broadcast %[[SHAPE16]] : vector<1x16xindex> to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADDVEC:.+]] = arith.addi %[[BROAD8]], %[[BROAD16]] : vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[MULI1:.+]] = arith.muli %[[ARG1]], %[[STR0:.+]] : index
+// CHECK:       LOAD_GATHER: %[[MULI2:.+]] = arith.muli %[[ARG2]], %[[STR1]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[MULI1]], %[[MULI2]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %[[ARG3]] : index
+// CHECK:       LOAD_GATHER: %[[BROADIDX:.+]] = vector.broadcast %[[ADDI2]] : index to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], %[[ADDVEC]] : vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// CHECK:       LOAD_GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:       LOAD_GATHER: gpu.return %[[RES]] : vector<8x16xf32>
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
+    %i: index, %j: index, %k: index) -> vector<8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%i, %j, %k], %c0
+    {in_bounds = [true, true]} : memref<?x8x16xf32>, vector<8x16xf32>
+  gpu.return %0 : vector<8x16xf32>
+}
+
+// CHECK-LABEL: LOAD_ND: @load_dynamic_source2(
+// CHECK-DAG:   LOAD_ND: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG:   LOAD_ND: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
+// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// CHECK:       LOAD_ND: return %[[VEC]] : vector<8x16xf32>
+
+// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source2(
+// CHECK-DAG:   LOAD_GATHER: %[[CST:.+]] = arith.constant dense<16> : vector<8xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// CHECK-DAG:   LOAD_GATHER: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG:   LOAD_GATHER: %[[C16:.+]] = arith.constant 16 : index
+// CHECK-DAG:   LOAD_GATHER: %[[STEP8:.+]] = vector.step : vector<8xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[STEP16:.+]] = vector.step : vector<16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[MUL8:.+]] = arith.muli %[[STEP8]], %[[CST]] : vector<8xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[SHAPE8:.+]] = vector.shape_cast %[[MUL8]] : vector<8xindex> to vector<8x1xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[SHAPE16:.+]] = vector.shape_cast %[[STEP16]] : vector<16xindex> to vector<1x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[BCAST8:.+]] = vector.broadcast %[[SHAPE8]] : vector<8x1xindex> to vector<8x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[BCAST16:.+]] = vector.broadcast %[[SHAPE16]] : vector<1x16xindex> to vector<8x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[ADDIDX:.+]] = arith.addi %[[BCAST8]], %[[BCAST16]] : vector<8x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[MULI1:.+]] = arith.muli %arg1, %[[C128]] : index
+// CHECK-DAG:   LOAD_GATHER: %[[MULI2:.+]] = arith.muli %arg2, %[[C16]] : index
+// CHECK-DAG:   LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[MULI1]], %[[MULI2]] : index
+// CHECK-DAG:   LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %arg3 : index
+// CHECK-DAG:   LOAD_GATHER: %[[BCASTIDX:.+]] = vector.broadcast %[[ADDI2]] : index to vector<8x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], %[[ADDIDX]] : vector<8x16xindex>
+// CHECK-DAG:   LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
+// CHECK:       LOAD_GATHER: return %[[VEC]] : vector<8x16xf32>
 
-// CHECK-LABEL: @load_dynamic_source(
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       return %[[VEC]]
 }
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
+    %i: index, %j: index, %k: index, %l: index, %m: index) -> vector<2x4x8x16xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%i, %j, %k, %l, %m], %c0
+    {in_bounds = [true, true, true, true]} : memref<?x?x?x?x?xf32>, vector<2x4x8x16xf32>
+  gpu.return %0 : vector<2x4x8x16xf32>
+}
+
+// CHECK-LABEL: LOAD_ND: @load_dynamic_source3(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source3(
+// CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<?x?x?x?x?xf32>, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index, %[[I4:.+]]: index
+// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
+// CHECK:       LOAD_GATHER: %[[BASE:.+]], %[[OFFSET:.+]], %[[SIZES:.*]], %[[strides:.*]] = memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
+// CHECK:       LOAD_GATHER: %[[STEP0:.+]] = vector.step : vector<2xindex>
+// CHECK:       LOAD_GATHER: %[[STEP1:.+]] = vector.step : vector<4xindex>
+// CHECK:       LOAD_GATHER: %[[STEP2:.+]] = vector.step : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[STEP3:.+]] = vector.step : vector<16xindex>
+// CHECK:       LOAD_GATHER: %[[BROAD0:.+]] = vector.broadcast %[[strides1:.+]] : index to vector<2xindex>
+// CHECK:       LOAD_GATHER: %[[MUL0:.+]] = arith.muli %[[STEP0]], %[[BROAD0]] : vector<2xindex>
+// CHECK:       LOAD_GATHER: %[[BROAD1:.+]] = vector.broadcast %[[strides2:.+]] : index to vector<4xindex>
+// CHECK:       LOAD_GATHER: %[[MUL1:.+]] = arith.muli %[[STEP1]], %[[BROAD1]] : vector<4xindex>
+// CHECK:       LOAD_GATHER: %[[BROAD2:.+]] = vector.broadcast %[[strides3:.+]] : index to vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[STEP2]], %[[BROAD2]] : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE0:.+]] = vector.shape_cast %[[MUL0]] : vector<2xindex> to vector<2x1x1x1xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE1:.+]] = vector.shape_cast %[[MUL1]] : vector<4xindex> to vector<1x4x1x1xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE2:.+]] = vector.shape_cast %[[MUL2]] : vector<8xindex> to vector<1x1x8x1xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE3:.+]] = vector.shape_cast %[[STEP3]] : vector<16xindex> to vector<1x1x1x16xindex>
+// CHECK:       LOAD_GATHER: %[[BC0:.+]] = vector.broadcast %[[SHAPE0]] : vector<2x1x1x1xindex> to vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[BC1:.+]] = vector.broadcast %[[SHAPE1]] : vector<1x4x1x1xindex> to vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[BC2:.+]] = vector.broadcast %[[SHAPE2]] : vector<1x1x8x1xindex> to vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[BC3:.+]] = vector.broadcast %[[SHAPE3]] : vector<1x1x1x16xindex> to vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADD0:.+]] = arith.addi %[[BC0]], %[[BC1]] : vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[ADD0]], %[[BC2]] : vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[BC3]] : vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[MULI0:.+]] = arith.muli %[[I0]], %[[strides0:.+]] : index
+// CHECK:       LOAD_GATHER: %[[MULI1:.+]] = arith.muli %[[I1]], %[[strides1:.+]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI0:.+]] = arith.addi %[[MULI0]], %[[MULI1]] : index
+// CHECK:       LOAD_GATHER: %[[MULI2:.+]] = arith.muli %[[I2]], %[[strides2:.+]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[ADDI0]], %[[MULI2]] : index
+// CHECK:       LOAD_GATHER: %[[MULI3:.+]] = arith.muli %[[I3]], %[[strides4:.+]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %[[MULI3]] : index
+// CHECK:       LOAD_GATHER: %[[ADDI3:.+]] = arith.addi %[[ADDI2]], %[[I4]] : index
+// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADDI3]] : index to vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[ADD2]] : vector<2x4x8x16xindex>
+// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
+// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// CHECK:       LOAD_GATHER: return %[[VEC]]
+}
+
 // -----
 gpu.module @xevm_module {
 gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,

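For the dynamic-shape tests above, the key fact is what memref.extract_strided_metadata exposes: the base buffer, then the offset, then the sizes, then the strides. The struct below is a plain-data model of that result layout (StridedMetadata and exampleForDynSource2 are illustrative names, not part of the patch); it also works through why memref<?x8x16xf32> still has the compile-time strides 128/16/1 that the load_dynamic_source2 checks above refer to:

  #include <cstdint>
  #include <vector>

  // Plain-data model of the memref.extract_strided_metadata results for a
  // rank-N strided memref: base buffer, offset, sizes, strides, in that order
  // (offset, sizes, and strides are all index-typed in the IR).
  struct StridedMetadata {
    float *baseBuffer;
    int64_t offset;
    std::vector<int64_t> sizes;
    std::vector<int64_t> strides;
  };

  // memref<?x8x16xf32>: only the outermost size is dynamic, so a contiguous
  // row-major layout has the compile-time strides {8 * 16, 16, 1}, i.e. the
  // constants 128 and 16 that the load_dynamic_source2 checks above match.
  StridedMetadata exampleForDynSource2(float *base, int64_t dim0) {
    return {base, /*offset=*/0, /*sizes=*/{dim0, 8, 16},
            /*strides=*/{8 * 16, 16, 1}};
  }

  int main() {
    float buffer[4 * 8 * 16] = {};
    StridedMetadata m = exampleForDynSource2(buffer, /*dim0=*/4);
    return (m.strides[0] == 128 && m.strides[1] == 16) ? 0 : 1;
  }
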
>From 87a79ad8b28e4c2182827a7261a1c240b8251454 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Sat, 2 Aug 2025 05:51:38 +0000
Subject: [PATCH 05/17] add tests

---
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 128 ++++++++++++------
 1 file changed, 89 insertions(+), 39 deletions(-)

diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index b880a70187353..aaf0f53258c9b 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -23,14 +23,14 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
 // CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  LOAD_GATHER: %[[OFFSET:.+]]: index
 // CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
-// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
 // CHECK:       LOAD_GATHER: %[[C512:.+]] = arith.constant 512 : index
+// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
 // CHECK:       LOAD_GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
 // CHECK:       LOAD_GATHER: %[[MUL1:.+]] = arith.muli %[[OFFSET]], %[[C512]] : index
 // CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
 // CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
 // CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
-// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.splat %[[ADD2]] : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADD2]] :  index to vector<8xindex>
 // CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
 // CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
 // CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
@@ -62,9 +62,9 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<8x16x32xf32>,
 // CHECK-SAME:  LOAD_GATHER: %[[OFFSET:.+]]: index
 // CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
-// CHECK:       LOAD_GATHER: %[[C512:.+]] = arith.constant 512 : index
 // CHECK:       LOAD_GATHER: %[[CST_0:.+]] = arith.constant dense<32> : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[C512:.+]] = arith.constant 512 : index
+// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
 // CHECK:       LOAD_GATHER: %[[STEP0:.+]] = vector.step : vector<8xindex>
 // CHECK:       LOAD_GATHER: %[[STEP1:.+]] = vector.step : vector<16xindex>
 // CHECK:       LOAD_GATHER: %[[MUL:.+]] = arith.muli %[[STEP0]], %[[CST_0]] : vector<8xindex>
@@ -77,13 +77,14 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
 // CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
 // CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
 // CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
-// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.splat %[[ADD2]] : vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADD2]] : index to vector<8x16xindex>
 // CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[ADD_VEC]] : vector<8x16xindex>
 // CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
 // CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 // CHECK:       LOAD_GATHER: return %[[VEC]]
 }
 
+
 // -----
 gpu.module @xevm_module {
 gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
@@ -94,34 +95,58 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @load_zero_pad_out_of_bounds(
-// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       return %[[VEC]]
+// CHECK-LABEL: LOAD_ND: @load_zero_pad_out_of_bounds(
+// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<32x64xf32>,
+// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
+// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// CHECK-SAME:  LOAD_ND:   memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// CHECK:       LOAD_ND: return %[[VEC]]
+
+// CHECK-LABEL: LOAD_GATHER: @load_zero_pad_out_of_bounds(
+// CHECK:       LOAD_GATHER: vector.transfer_read
+
 }
 
+
 // -----
 gpu.module @xevm_module {
 gpu.func @load_transposed(%source: memref<32x64xf32>,
-    %offset: index) -> vector<8x16xf32> {
+    %i: index, %j: index) -> vector<8x16xf32> {
   %c0 = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %source[%offset, %offset], %c0
+  %0 = vector.transfer_read %source[%i, %j], %c0
     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
     in_bounds = [true, true]} : memref<32x64xf32>, vector<8x16xf32>
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @load_transposed(
-// CHECK-SAME:  %[[SRC:.+]]: memref<32x64xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// CHECK:       %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
-// CHECK-SAME:    -> vector<8x16xf32>
-// CHECK:       return %[[VEC]]
+// CHECK-LABEL: LOAD_ND: @load_transposed(%[[SRC:.+]]: memref<32x64xf32>, %[[OFFSET1:.+]]: index, %[[OFFSET2:.+]]: index
+// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// CHECK-SAME:  LOAD_ND:   memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// CHECK-SAME:  LOAD_ND:   -> vector<8x16xf32>
+// CHECK:       LOAD_ND: return %[[VEC]]
+
+
+// CHECK-LABEL: LOAD_GATHER: @load_transposed(
+// CHECK:       LOAD_GATHER: %[[CST:.*]] = arith.constant dense<true> : vector<8x16xi1>
+// CHECK:       LOAD_GATHER: %[[CST_0:.*]] = arith.constant dense<64> : vector<16xindex>
+// CHECK:       LOAD_GATHER: %[[C64:.*]] = arith.constant 64 : index
+// CHECK:       LOAD_GATHER: %[[STEP0:.*]] = vector.step : vector<8xindex>
+// CHECK:       LOAD_GATHER: %[[STEP1:.*]] = vector.step : vector<16xindex>
+// CHECK:       LOAD_GATHER: %[[MUL1:.*]] = arith.muli %[[STEP1]], %[[CST_0]] : vector<16xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE0:.*]] = vector.shape_cast %[[STEP0]] : vector<8xindex> to vector<8x1xindex>
+// CHECK:       LOAD_GATHER: %[[SHAPE1:.*]] = vector.shape_cast %[[MUL1]] : vector<16xindex> to vector<1x16xindex>
+// CHECK:       LOAD_GATHER: %[[BCAST0:.*]] = vector.broadcast %[[SHAPE0]] : vector<8x1xindex> to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[BCAST1:.*]] = vector.broadcast %[[SHAPE1]] : vector<1x16xindex> to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADD1:.*]] = arith.addi %[[BCAST0]], %[[BCAST1]] : vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[MUL2:.*]] = arith.muli %arg2, %[[C64]] : index
+// CHECK:       LOAD_GATHER: %[[ADD2:.*]] = arith.addi %arg1, %[[MUL2]] : index
+// CHECK:       LOAD_GATHER: %[[BCAST2:.*]] = vector.broadcast %[[ADD2]] : index to vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[ADD3:.*]] = arith.addi %[[BCAST2]], %[[ADD1]] : vector<8x16xindex>
+// CHECK:       LOAD_GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
+// CHECK:       LOAD_GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[ADD3]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// CHECK:       LOAD_GATHER: gpu.return %[[LOAD]] : vector<8x16xf32>
 }
 
 // -----
@@ -278,8 +303,11 @@ gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
   gpu.return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
 }
 
-// CHECK-LABEL:   @no_load_out_of_bounds_non_zero_pad(
-// CHECK-COUNT-2: vector.transfer_read
+// CHECK-LABEL:   LOAD_ND: @no_load_out_of_bounds_non_zero_pad(
+// CHECK-COUNT-2: LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL:   LOAD_GATHER: @no_load_out_of_bounds_non_zero_pad(
+// CHECK-COUNT-2: LOAD_GATHER: vector.transfer_read
 }
 
 // -----
@@ -292,9 +320,13 @@ gpu.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
   gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: @no_load_out_of_bounds_1D_vector(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_out_of_bounds_1D_vector(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_out_of_bounds_1D_vector(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
+
 // -----
 gpu.module @xevm_module {
 gpu.func @no_load_masked(%source : memref<4xf32>,
@@ -306,8 +338,11 @@ gpu.func @no_load_masked(%source : memref<4xf32>,
   gpu.return %0 : vector<4xf32>
 }
 
-// CHECK-LABEL: @no_load_masked(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_masked(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_masked(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
 
 // -----
@@ -320,8 +355,11 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @no_load_tensor(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_tensor(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_tensor(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
 
 // -----
@@ -334,8 +372,11 @@ gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16x32xf32>
 }
 
-// CHECK-LABEL: @no_load_high_dim_vector(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_high_dim_vector(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_high_dim_vector(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
 
 // -----
@@ -349,10 +390,14 @@ gpu.func @no_load_non_unit_inner_stride(
   gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: @no_load_non_unit_inner_stride(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_non_unit_inner_stride(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_non_unit_inner_stride(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
 
+
 // -----
 gpu.module @xevm_module {
 gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
@@ -364,8 +409,11 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: @no_load_unsupported_map(
-// CHECK:       vector.transfer_read
+// CHECK-LABEL: LOAD_ND: @no_load_unsupported_map(
+// CHECK:       LOAD_ND: vector.transfer_read
+
+// CHECK-LABEL: LOAD_GATHER: @no_load_unsupported_map(
+// CHECK:       LOAD_GATHER: vector.transfer_read
 }
 
 // -----
@@ -379,7 +427,9 @@ gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
   gpu.return %0 : vector<8x16xf16>
 }
 
-// CHECK-LABEL: @no_load_transpose_unsupported_data_type(
-// CHECK:       vector.transfer_read
-}
+// CHECK-LABEL: LOAD_ND: @no_load_transpose_unsupported_data_type(
+// CHECK:       LOAD_ND: vector.transfer_read
 
+// CHECK-LABEL: LOAD_GATHER: @no_load_transpose_unsupported_data_type(
+// CHECK:       LOAD_GATHER: vector.transfer_read
+}

>From e07f3c13d1e598c3838e26786af68560c05d3dde Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 6 Aug 2025 18:41:38 +0000
Subject: [PATCH 06/17] add tests

---
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 418 ++++++++----------
 1 file changed, 183 insertions(+), 235 deletions(-)

diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index aaf0f53258c9b..eebdb02ee2354 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -9,32 +9,26 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
   gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_1D_vector(
-// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  LOAD_ND:   %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  LOAD_ND:   memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// CHECK-SAME:  LOAD_ND:   boundary_check = false
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
-// CHECK:       LOAD_ND: return %[[VEC]]
-
-// CHECK-LABEL: LOAD_GATHER: @load_1D_vector(
-// CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  LOAD_GATHER: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
-// CHECK:       LOAD_GATHER: %[[C512:.+]] = arith.constant 512 : index
-// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
-// CHECK:       LOAD_GATHER: %[[STEP:.+]] = vector.step : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[MUL1:.+]] = arith.muli %[[OFFSET]], %[[C512]] : index
-// CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
-// CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
-// CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
-// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADD2]] :  index to vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
-// CHECK:       LOAD_GATHER: return %[[VEC]]
+// LOAD_ND-LABEL:  @load_1D_vector(
+// LOAD_ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD_ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD_ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD_ND-SAME:     boundary_check = false
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD_ND:        return %[[VEC]]
+
+// LOAD_GATHER-LABEL:  @load_1D_vector(
+// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD_GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD_GATHER-COUNT-2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT-2: arith.addi {{.*}} : index
+// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -48,40 +42,29 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_2D_vector(
-// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:  LOAD_ND:   %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  LOAD_ND:   memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME:  LOAD_ND:   boundary_check = false
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       LOAD_ND: return %[[VEC]]
-
-// CHECK-LABEL: LOAD_GATHER: @load_2D_vector(
-// CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  LOAD_GATHER: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// CHECK:       LOAD_GATHER: %[[CST_0:.+]] = arith.constant dense<32> : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[C512:.+]] = arith.constant 512 : index
-// CHECK:       LOAD_GATHER: %[[C32:.+]] = arith.constant 32 : index
-// CHECK:       LOAD_GATHER: %[[STEP0:.+]] = vector.step : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[STEP1:.+]] = vector.step : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[MUL:.+]] = arith.muli %[[STEP0]], %[[CST_0]] : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE0:.+]] = vector.shape_cast %[[MUL]] : vector<8xindex> to vector<8x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE1:.+]] = vector.shape_cast %[[STEP1]] : vector<16xindex> to vector<1x16xindex>
-// CHECK:       LOAD_GATHER: %[[BROADCAST0:.+]] = vector.broadcast %[[SHAPE0]] : vector<8x1xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BROADCAST1:.+]] = vector.broadcast %[[SHAPE1]] : vector<1x16xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD_VEC:.+]] = arith.addi %[[BROADCAST0]], %[[BROADCAST1]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[MUL1:.+]] = arith.muli %[[OFFSET]], %[[C512]] : index
-// CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[OFFSET]], %[[C32]] : index
-// CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[MUL1]], %[[MUL2]] : index
-// CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[OFFSET]] : index
-// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADD2]] : index to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[ADD_VEC]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// CHECK:       LOAD_GATHER: return %[[VEC]]
+// LOAD_ND-LABEL:  @load_2D_vector(
+// LOAD_ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD_ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD_ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD_ND-SAME:     boundary_check = false
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD_ND:        return %[[VEC]]
+
+// LOAD_GATHER-LABEL:  @load_2D_vector(
+// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER-COUNT-2: vector.step
+// LOAD_GATHER-COUNT-2: vector.shape_cast
+// LOAD_GATHER-COUNT-2: vector.broadcast
+// LOAD_GATHER-COUNT-2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT-2: arith.addi {{.*}} : index
+// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+
 }
 
 
@@ -95,16 +78,16 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_zero_pad_out_of_bounds(
-// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<32x64xf32>,
-// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  LOAD_ND:   memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       LOAD_ND: return %[[VEC]]
+// LOAD_ND-LABEL:  @load_zero_pad_out_of_bounds(
+// LOAD_ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD_ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD_ND:        return %[[VEC]]
 
-// CHECK-LABEL: LOAD_GATHER: @load_zero_pad_out_of_bounds(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @load_zero_pad_out_of_bounds(
+// LOAD_GATHER:        vector.transfer_read
 
 }
 
@@ -120,33 +103,30 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_transposed(%[[SRC:.+]]: memref<32x64xf32>, %[[OFFSET1:.+]]: index, %[[OFFSET2:.+]]: index
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
-// CHECK-SAME:  LOAD_ND:   memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
-// CHECK-SAME:  LOAD_ND:   -> vector<8x16xf32>
-// CHECK:       LOAD_ND: return %[[VEC]]
-
-
-// CHECK-LABEL: LOAD_GATHER: @load_transposed(
-// CHECK:       LOAD_GATHER: %[[CST:.*]] = arith.constant dense<true> : vector<8x16xi1>
-// CHECK:       LOAD_GATHER: %[[CST_0:.*]] = arith.constant dense<64> : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[C64:.*]] = arith.constant 64 : index
-// CHECK:       LOAD_GATHER: %[[STEP0:.*]] = vector.step : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[STEP1:.*]] = vector.step : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[MUL1:.*]] = arith.muli %[[STEP1]], %[[CST_0]] : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE0:.*]] = vector.shape_cast %[[STEP0]] : vector<8xindex> to vector<8x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE1:.*]] = vector.shape_cast %[[MUL1]] : vector<16xindex> to vector<1x16xindex>
-// CHECK:       LOAD_GATHER: %[[BCAST0:.*]] = vector.broadcast %[[SHAPE0]] : vector<8x1xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BCAST1:.*]] = vector.broadcast %[[SHAPE1]] : vector<1x16xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD1:.*]] = arith.addi %[[BCAST0]], %[[BCAST1]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[MUL2:.*]] = arith.muli %arg2, %[[C64]] : index
-// CHECK:       LOAD_GATHER: %[[ADD2:.*]] = arith.addi %arg1, %[[MUL2]] : index
-// CHECK:       LOAD_GATHER: %[[BCAST2:.*]] = vector.broadcast %[[ADD2]] : index to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD3:.*]] = arith.addi %[[BCAST2]], %[[ADD1]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// CHECK:       LOAD_GATHER: %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[ADD3]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// CHECK:       LOAD_GATHER: gpu.return %[[LOAD]] : vector<8x16xf32>
+// LOAD_ND-LABEL:  @load_transposed(
+// LOAD_ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD_ND-SAME:   %[[OFFSET1:.+]]: index, 
+// LOAD_ND-SAME:   %[[OFFSET2:.+]]: index  
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD_ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD_ND-SAME:     -> vector<8x16xf32>
+// LOAD_ND:        return %[[VEC]]
+
+
+// LOAD_GATHER-LABEL:  @load_transposed(
+// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD_GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER-COUNT-2: vector.step
+// LOAD_GATHER-COUNT-2: vector.shape_cast
+// LOAD_GATHER-COUNT-2: vector.broadcast
+// LOAD_GATHER-COUNT-2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT-2: arith.addi {{.*}} : index
+// LOAD_GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
+// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+
 }
 
 // -----
@@ -158,42 +138,35 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
     {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
   gpu.return %0 : vector<8x16xf32>
 }
-// CHECK-LABEL: LOAD_ND: @load_dynamic_source(
-// CHECK-SAME:  LOAD_ND: %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  LOAD_ND: %[[OFFSET:.+]]: index
-// CHECK:       LOAD_ND: %[[C2:.+]] = arith.constant 2 : index
-// CHECK:       LOAD_ND: %[[C1:.+]] = arith.constant 1 : index
-// CHECK:       LOAD_ND: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   LOAD_ND: %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   LOAD_ND: %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   LOAD_ND: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       LOAD_ND: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// CHECK:       LOAD_ND: return %[[VEC]]
-
-
-// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source(%[[ARG0:.+]]: memref<?x?x?xf32>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index)
-// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// CHECK:       LOAD_GATHER: memref.extract_strided_metadata %[[ARG0]]
-// CHECK:       LOAD_GATHER: %[[STEP8:.+]] = vector.step : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[STEP16:.+]] = vector.step : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[BSTR1:.+]] = vector.broadcast %[[STR1:.+]] : index to vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[MUL8:.+]] = arith.muli %[[STEP8]], %[[BSTR1]] : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE8:.+]] = vector.shape_cast %[[MUL8]] : vector<8xindex> to vector<8x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE16:.+]] = vector.shape_cast %[[STEP16]] : vector<16xindex> to vector<1x16xindex>
-// CHECK:       LOAD_GATHER: %[[BROAD8:.+]] = vector.broadcast %[[SHAPE8]] : vector<8x1xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BROAD16:.+]] = vector.broadcast %[[SHAPE16]] : vector<1x16xindex> to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADDVEC:.+]] = arith.addi %[[BROAD8]], %[[BROAD16]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[MULI1:.+]] = arith.muli %[[ARG1]], %[[STR0:.+]] : index
-// CHECK:       LOAD_GATHER: %[[MULI2:.+]] = arith.muli %[[ARG2]], %[[STR1]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[MULI1]], %[[MULI2]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %[[ARG3]] : index
-// CHECK:       LOAD_GATHER: %[[BROADIDX:.+]] = vector.broadcast %[[ADDI2]] : index to vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], %[[ADDVEC]] : vector<8x16xindex>
-// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// CHECK:       LOAD_GATHER: %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// CHECK:       LOAD_GATHER: gpu.return %[[RES]] : vector<8x16xf32>
+// LOAD_ND-LABEL:  @load_dynamic_source(
+// LOAD_ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
+// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD_ND:        %[[C2:.+]] = arith.constant 2 : index
+// LOAD_ND:        %[[C1:.+]] = arith.constant 1 : index
+// LOAD_ND:        %[[C0:.+]] = arith.constant 0 : index
+// LOAD_ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// LOAD_ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// LOAD_ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// LOAD_ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD_ND:        return %[[VEC]]
+
+
+// LOAD_GATHER-LABEL:  @load_dynamic_source(
+// LOAD_GATHER-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>,
+// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER:        memref.extract_strided_metadata %[[ARG0]]
+// LOAD_GATHER-COUNT2: vector.step
+// LOAD_GATHER-COUNT2: vector.shape_cast
+// LOAD_GATHER-COUNT2: vector.broadcast
+// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD_GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD_GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// LOAD_GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD_GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
 }
 
 // -----
@@ -206,35 +179,24 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_dynamic_source2(
-// CHECK-DAG:   LOAD_ND: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   LOAD_ND: %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// CHECK:       LOAD_ND: %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// CHECK:       LOAD_ND: %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
-// CHECK:       LOAD_ND: return %[[VEC]] : vector<8x16xf32>
-
-// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source2(
-// CHECK-DAG:   LOAD_GATHER: %[[CST:.+]] = arith.constant dense<16> : vector<8xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// CHECK-DAG:   LOAD_GATHER: %[[C128:.+]] = arith.constant 128 : index
-// CHECK-DAG:   LOAD_GATHER: %[[C16:.+]] = arith.constant 16 : index
-// CHECK-DAG:   LOAD_GATHER: %[[STEP8:.+]] = vector.step : vector<8xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[STEP16:.+]] = vector.step : vector<16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[MUL8:.+]] = arith.muli %[[STEP8]], %[[CST]] : vector<8xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[SHAPE8:.+]] = vector.shape_cast %[[MUL8]] : vector<8xindex> to vector<8x1xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[SHAPE16:.+]] = vector.shape_cast %[[STEP16]] : vector<16xindex> to vector<1x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[BCAST8:.+]] = vector.broadcast %[[SHAPE8]] : vector<8x1xindex> to vector<8x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[BCAST16:.+]] = vector.broadcast %[[SHAPE16]] : vector<1x16xindex> to vector<8x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[ADDIDX:.+]] = arith.addi %[[BCAST8]], %[[BCAST16]] : vector<8x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[MULI1:.+]] = arith.muli %arg1, %[[C128]] : index
-// CHECK-DAG:   LOAD_GATHER: %[[MULI2:.+]] = arith.muli %arg2, %[[C16]] : index
-// CHECK-DAG:   LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[MULI1]], %[[MULI2]] : index
-// CHECK-DAG:   LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %arg3 : index
-// CHECK-DAG:   LOAD_GATHER: %[[BCASTIDX:.+]] = vector.broadcast %[[ADDI2]] : index to vector<8x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], %[[ADDIDX]] : vector<8x16xindex>
-// CHECK-DAG:   LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
-// CHECK:       LOAD_GATHER: return %[[VEC]] : vector<8x16xf32>
+// LOAD_ND-LABEL:  @load_dynamic_source2(
+// LOAD_ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// LOAD_ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
+// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD_ND:        return %[[VEC]] : vector<8x16xf32>
+
+// LOAD_GATHER-LABEL:  @load_dynamic_source2(
+// LOAD_GATHER-DAG:    %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER-COUNT2: vector.step
+// LOAD_GATHER-COUNT2: vector.shape_cast
+// LOAD_GATHER-COUNT2: vector.broadcast
+// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD_GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD_GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
+// LOAD_GATHER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
 
 }
 
@@ -248,47 +210,23 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
   gpu.return %0 : vector<2x4x8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @load_dynamic_source3(
-// CHECK:       LOAD_ND: vector.transfer_read
-
-// CHECK-LABEL: LOAD_GATHER: @load_dynamic_source3(
-// CHECK-SAME:  LOAD_GATHER: %[[SRC:.+]]: memref<?x?x?x?x?xf32>, %[[I0:.+]]: index, %[[I1:.+]]: index, %[[I2:.+]]: index, %[[I3:.+]]: index, %[[I4:.+]]: index
-// CHECK:       LOAD_GATHER: %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
-// CHECK:       LOAD_GATHER: %[[BASE:.+]], %[[OFFSET:.+]], %[[SIZES:.*]], %[[strides:.*]] = memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
-// CHECK:       LOAD_GATHER: %[[STEP0:.+]] = vector.step : vector<2xindex>
-// CHECK:       LOAD_GATHER: %[[STEP1:.+]] = vector.step : vector<4xindex>
-// CHECK:       LOAD_GATHER: %[[STEP2:.+]] = vector.step : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[STEP3:.+]] = vector.step : vector<16xindex>
-// CHECK:       LOAD_GATHER: %[[BROAD0:.+]] = vector.broadcast %[[strides1:.+]] : index to vector<2xindex>
-// CHECK:       LOAD_GATHER: %[[MUL0:.+]] = arith.muli %[[STEP0]], %[[BROAD0]] : vector<2xindex>
-// CHECK:       LOAD_GATHER: %[[BROAD1:.+]] = vector.broadcast %[[strides2:.+]] : index to vector<4xindex>
-// CHECK:       LOAD_GATHER: %[[MUL1:.+]] = arith.muli %[[STEP1]], %[[BROAD1]] : vector<4xindex>
-// CHECK:       LOAD_GATHER: %[[BROAD2:.+]] = vector.broadcast %[[strides3:.+]] : index to vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[MUL2:.+]] = arith.muli %[[STEP2]], %[[BROAD2]] : vector<8xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE0:.+]] = vector.shape_cast %[[MUL0]] : vector<2xindex> to vector<2x1x1x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE1:.+]] = vector.shape_cast %[[MUL1]] : vector<4xindex> to vector<1x4x1x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE2:.+]] = vector.shape_cast %[[MUL2]] : vector<8xindex> to vector<1x1x8x1xindex>
-// CHECK:       LOAD_GATHER: %[[SHAPE3:.+]] = vector.shape_cast %[[STEP3]] : vector<16xindex> to vector<1x1x1x16xindex>
-// CHECK:       LOAD_GATHER: %[[BC0:.+]] = vector.broadcast %[[SHAPE0]] : vector<2x1x1x1xindex> to vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BC1:.+]] = vector.broadcast %[[SHAPE1]] : vector<1x4x1x1xindex> to vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BC2:.+]] = vector.broadcast %[[SHAPE2]] : vector<1x1x8x1xindex> to vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[BC3:.+]] = vector.broadcast %[[SHAPE3]] : vector<1x1x1x16xindex> to vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD0:.+]] = arith.addi %[[BC0]], %[[BC1]] : vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD1:.+]] = arith.addi %[[ADD0]], %[[BC2]] : vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[ADD2:.+]] = arith.addi %[[ADD1]], %[[BC3]] : vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[MULI0:.+]] = arith.muli %[[I0]], %[[strides0:.+]] : index
-// CHECK:       LOAD_GATHER: %[[MULI1:.+]] = arith.muli %[[I1]], %[[strides1:.+]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI0:.+]] = arith.addi %[[MULI0]], %[[MULI1]] : index
-// CHECK:       LOAD_GATHER: %[[MULI2:.+]] = arith.muli %[[I2]], %[[strides2:.+]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI1:.+]] = arith.addi %[[ADDI0]], %[[MULI2]] : index
-// CHECK:       LOAD_GATHER: %[[MULI3:.+]] = arith.muli %[[I3]], %[[strides4:.+]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI2:.+]] = arith.addi %[[ADDI1]], %[[MULI3]] : index
-// CHECK:       LOAD_GATHER: %[[ADDI3:.+]] = arith.addi %[[ADDI2]], %[[I4]] : index
-// CHECK:       LOAD_GATHER: %[[SPLAT:.+]] = vector.broadcast %[[ADDI3]] : index to vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[ADD2]] : vector<2x4x8x16xindex>
-// CHECK:       LOAD_GATHER: %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// CHECK:       LOAD_GATHER: %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
-// CHECK:       LOAD_GATHER: return %[[VEC]]
+// LOAD_ND-LABEL:  @load_dynamic_source3(
+// LOAD_ND:        vector.transfer_read
+
+// LOAD_GATHER-LABEL:  @load_dynamic_source3(
+// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<?x?x?x?x?xf32>
+// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
+// LOAD_GATHER:        memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
+// LOAD_GATHER-COUNT4: vector.step
+// LOAD_GATHER-COUNT3: vector.broadcast
+// LOAD_GATHER-COUNT4: vector.shape_cast
+// LOAD_GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
+// LOAD_GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
+// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
+// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD_GATHER:        return %[[VEC]]
 }
 
 // -----
@@ -303,11 +241,11 @@ gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
   gpu.return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
 }
 
-// CHECK-LABEL:   LOAD_ND: @no_load_out_of_bounds_non_zero_pad(
-// CHECK-COUNT-2: LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:    @no_load_out_of_bounds_non_zero_pad(
+// LOAD_ND-COUNT-2: vector.transfer_read
 
-// CHECK-LABEL:   LOAD_GATHER: @no_load_out_of_bounds_non_zero_pad(
-// CHECK-COUNT-2: LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// LOAD_GATHER-COUNT-2: vector.transfer_read
 }
 
 // -----
@@ -320,11 +258,11 @@ gpu.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
   gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_out_of_bounds_1D_vector(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_out_of_bounds_1D_vector(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_out_of_bounds_1D_vector(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_out_of_bounds_1D_vector(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 // -----
@@ -338,11 +276,11 @@ gpu.func @no_load_masked(%source : memref<4xf32>,
   gpu.return %0 : vector<4xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_masked(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_masked(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_masked(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_masked(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 // -----
@@ -355,11 +293,11 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_tensor(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_tensor(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_tensor(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_tensor(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 // -----
@@ -372,11 +310,11 @@ gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16x32xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_high_dim_vector(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_high_dim_vector(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_high_dim_vector(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_high_dim_vector(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 // -----
@@ -390,11 +328,11 @@ gpu.func @no_load_non_unit_inner_stride(
   gpu.return %0 : vector<8xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_non_unit_inner_stride(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_non_unit_inner_stride(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_non_unit_inner_stride(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_non_unit_inner_stride(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 
@@ -409,11 +347,11 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_unsupported_map(
-// CHECK:       LOAD_ND: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_unsupported_map(
+// LOAD_ND:        vector.transfer_read
 
-// CHECK-LABEL: LOAD_GATHER: @no_load_unsupported_map(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_GATHER-LABEL:  @no_load_unsupported_map(
+// LOAD_GATHER:        vector.transfer_read
 }
 
 // -----
@@ -427,9 +365,19 @@ gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
   gpu.return %0 : vector<8x16xf16>
 }
 
-// CHECK-LABEL: LOAD_ND: @no_load_transpose_unsupported_data_type(
-// CHECK:       LOAD_ND: vector.transfer_read
-
-// CHECK-LABEL: LOAD_GATHER: @no_load_transpose_unsupported_data_type(
-// CHECK:       LOAD_GATHER: vector.transfer_read
+// LOAD_ND-LABEL:  @no_load_transpose_unsupported_data_type(
+// LOAD_ND:        vector.transfer_read
+
+// LOAD_GATHER-LABEL:  @no_load_transpose_unsupported_data_type(
+// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD_GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD_GATHER-COUNT2:  vector.step
+// LOAD_GATHER-COUNT2:  vector.shape_cast
+// LOAD_GATHER-COUNT2: vector.broadcast
+// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD_GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
+// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 }

>From 0ca20d4a63e196fdf7452d8b7a8d0fd9c895b985 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 6 Aug 2025 23:55:51 +0000
Subject: [PATCH 07/17] fix issues

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 25 ++++++-------------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 15 ++++++++---
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 21 +++++++++++++---
 mlir/test/Dialect/XeGPU/invalid.mlir          | 10 ++++----
 4 files changed, 41 insertions(+), 30 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 20c11198d67c8..2a6fedbab54ee 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -394,9 +394,9 @@ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
 // Create XeGPU gather load operation
 static LogicalResult createLoadGather(vector::TransferReadOp readOp,
                                       PatternRewriter &rewriter,
-                                      Value flatMemref, Value localOffsets,
-                                      VectorType vectorType) {
+                                      Value flatMemref, Value localOffsets) {
   Location loc = readOp.getLoc();
+  VectorType vectorType = readOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
   Value mask = rewriter.create<vector::ConstantMaskOp>(
       loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
@@ -412,10 +412,10 @@ static LogicalResult createLoadGather(vector::TransferReadOp readOp,
 
 // Create XeGPU store scatter operation
 static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
-                                        PatternRewriter &rewriter,
-                                        Value flatMemref, Value localOffsets,
-                                        Value value, VectorType vectorType) {
+                                        PatternRewriter &rewriter, Value value,
+                                        Value flatMemref, Value localOffsets) {
   Location loc = writeOp.getLoc();
+  VectorType vectorType = writeOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
   Value mask = rewriter.create<vector::ConstantMaskOp>(
       loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
@@ -436,17 +436,13 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
-  VectorType vectorType = readOp.getVectorType();
-  Type elementType = vectorType.getElementType();
-
   SmallVector<Value> strides = computeStrides(readOp, rewriter);
 
   Value localOffsets = computeGatherOffsets(readOp, rewriter, strides);
 
   Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
 
-  return createLoadGather(readOp, rewriter, flatMemref, localOffsets,
-                          vectorType);
+  return createLoadGather(readOp, rewriter, flatMemref, localOffsets);
 }
 
 LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
@@ -456,13 +452,6 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
   if (!memrefType)
     return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
 
-  Value baseMemref = writeOp.getBase();
-  AffineMap permMap = writeOp.getPermutationMap();
-  VectorType vectorType = writeOp.getVectorType();
-  Type elementType = vectorType.getElementType();
-  SmallVector<Value> indices(writeOp.getIndices().begin(),
-                             writeOp.getIndices().end());
-
   SmallVector<Value> strides = computeStrides(writeOp, rewriter);
 
   Value localOffsets = computeGatherOffsets(writeOp, rewriter, strides);
@@ -470,7 +459,7 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
   Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
 
   return createStoreScatter(writeOp, rewriter, writeOp.getVector(), flatMemref,
-                            localOffsets, vectorType);
+                            localOffsets);
 }
 
 static LogicalResult
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 570689bc0969e..8f67339f7cfe8 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -123,9 +123,18 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
 
   // a valid shape for SIMT case
   if (valueTy.getRank() == 1) {
-    if (valueTy.getNumElements() % chunkSize != 0)
-      return emitError() << "value elements must match chunk size " << chunkSize
-                         << " for SIMT code.";
+    auto maskVecTy = dyn_cast<VectorType>(maskTy);
+    if (!maskVecTy)
+      return emitError() << "Expecting a vector type mask.";
+    int64_t maskElements = maskVecTy.getNumElements();
+
+    auto valueSize = valueTy.getNumElements();
+    if ((valueSize % chunkSize) != 0)
+      return emitError() << "value elements must be multiple of chunk size "
+                         << chunkSize;
+    if ((valueSize / chunkSize) != maskElements)
+      return emitError()
+             << "Mask should match value except the chunk size dim.";
     return success();
   }
 
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index eebdb02ee2354..ab03416ca8962 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -314,7 +314,20 @@ gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
 // LOAD_ND:        vector.transfer_read
 
 // LOAD_GATHER-LABEL:  @no_load_high_dim_vector(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// LOAD_GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// LOAD_GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// LOAD_GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
+// LOAD_GATHER:        %[[C64:.+]] = arith.constant 64 : index
+// LOAD_GATHER-COUNT3: vector.step
+// LOAD_GATHER-COUNT3: vector.shape_cast
+// LOAD_GATHER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// LOAD_GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
+// LOAD_GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+
 }
 
 // -----
@@ -369,7 +382,7 @@ gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
 // LOAD_ND:        vector.transfer_read
 
 // LOAD_GATHER-LABEL:  @no_load_transpose_unsupported_data_type(
-// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
 // LOAD_GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
 // LOAD_GATHER-COUNT2:  vector.step
 // LOAD_GATHER-COUNT2:  vector.shape_cast
@@ -378,6 +391,6 @@ gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
 // LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
 // LOAD_GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
 // LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
+// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index dff3ffab39ecf..9a1a3de9e233a 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -407,8 +407,8 @@ func.func @load_gather_offset_sg(%src: memref<?xf16>) {
 func.func @load_gather_offset_wi(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error@+1 {{value elements must match chunk size}}
-  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf32>
+  // expected-error@+1 {{value elements must be multiple of chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<3xf32>
   return
 }
 
@@ -417,7 +417,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %val = arith.constant dense<2.9>: vector<4xf16>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
-  // expected-error@+1 {{value elements must match chunk size}}
+  // expected-error@+1 {{Mask should match value except the chunk size dim}}
   xegpu.store %val, %src[%offsets], %mask 
         : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
@@ -438,8 +438,8 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
 func.func @load_gather_offset_wi_2(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error@+1 {{value elements must match chunk size}}
-  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf16>
+  // expected-error@+1 {{value elements must be multiple of chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<3xf16>
   return
 }
 

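For context on the revised rank-1 (SIMT) check in isValidGatherScatterBufferParams above: the value length must be a multiple of chunk_size, and value length divided by chunk_size must equal the mask length. A minimal sketch of shapes that pass and fail it (illustrative only, reusing the op syntax from invalid.mlir; not part of the patch):

    // 2 value elements, chunk_size = 2, 1-element mask: accepted (2 % 2 == 0, 2 / 2 == 1).
    %ok = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
    // 4 value elements, default chunk_size = 1, 1-element mask: rejected with
    // "Mask should match value except the chunk size dim." (4 / 1 != 1).
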
>From a442adc37e602cee75e38e8dcdc7b9f5acb77819 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 7 Aug 2025 01:32:39 +0000
Subject: [PATCH 08/17] enable transfer-write lowering

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  88 ++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 392 +++++++++---------
 .../transfer-write-to-xegpu.mlir              | 200 ++++++---
 3 files changed, 386 insertions(+), 294 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2a6fedbab54ee..c59a2060de6a3 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -152,43 +152,40 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
-static void adjustStridesForPermutation(Operation *op,
-                                        PatternRewriter &rewriter,
-                                        MemRefType memrefType,
-                                        AffineMap permMap, VectorType vecType,
-                                        SmallVectorImpl<Value> &strides) {
+static LogicalResult adjustStridesForPermutation(
+    Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
+    AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
   unsigned vecRank;
   unsigned memrefRank = memrefType.getRank();
 
-  if (!permMap.isMinorIdentity()) {
-    vecRank = vecType.getRank();
-    // Only adjust the last vecRank strides according to the permutation
-    ArrayRef<Value> relevantStrides =
-        ArrayRef<Value>(strides).take_back(vecRank);
-    SmallVector<Value> adjustedStrides(vecRank);
-    // For each output dimension in the permutation map, find which input dim it
-    // refers to, and assign the corresponding stride.
-    for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
-      AffineExpr expr = permMap.getResult(outIdx);
-      auto dimExpr = dyn_cast<AffineDimExpr>(expr);
-      if (!dimExpr) {
-        rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
-        return;
-      }
-      unsigned pos = dimExpr.getPosition();
-      // Map permutation to the relevant strides (innermost dims)
-      if (pos < memrefRank - vecRank) {
-        rewriter.notifyMatchFailure(op, "Permutation out of bounds");
-        return;
-      }
-      // The stride for output dimension outIdx is the stride of input dimension
-      // pos
-      adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
+  if (permMap.isMinorIdentity())
+    return success();
+  vecRank = vecType.getRank();
+  // Only adjust the last vecRank strides according to the permutation
+  ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
+  SmallVector<Value> adjustedStrides(vecRank);
+  // For each output dimension in the permutation map, find which input dim it
+  // refers to, and assign the corresponding stride.
+  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
+    AffineExpr expr = permMap.getResult(outIdx);
+    auto dimExpr = dyn_cast<AffineDimExpr>(expr);
+    if (!dimExpr) {
+      return rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
     }
-    // Replace the last vecRank strides with the adjusted ones
-    for (unsigned i = 0; i < vecRank; ++i)
-      strides[memrefRank - vecRank + i] = adjustedStrides[i];
+    unsigned pos = dimExpr.getPosition();
+    // Map permutation to the relevant strides (innermost dims)
+    if (pos < memrefRank - vecRank) {
+      return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+    }
+    // The stride for output dimension outIdx is the stride of input dimension
+    // pos
+    adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
   }
+  // Replace the last vecRank strides with the adjusted ones
+  for (unsigned i = 0; i < vecRank; ++i)
+    strides[memrefRank - vecRank + i] = adjustedStrides[i];
+
+  return success();
 }
 
 SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
@@ -204,7 +201,6 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     int64_t offset;
     SmallVector<int64_t> intStrides;
     if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
-      rewriter.notifyMatchFailure(xferOp, "Failed to get memref strides");
       return {};
     }
     // Wrap static strides as MLIR values
@@ -234,8 +230,10 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     strides.append(meta.getStrides().begin(), meta.getStrides().end());
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
-  adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap, vectorType,
-                              strides);
+  if (failed(adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap,
+                                         vectorType, strides))) {
+    return {};
+  }
   return strides;
 }
 
@@ -437,6 +435,8 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
     return rewriter.notifyMatchFailure(readOp, "Expected memref source");
 
   SmallVector<Value> strides = computeStrides(readOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
   Value localOffsets = computeGatherOffsets(readOp, rewriter, strides);
 
@@ -463,30 +463,30 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
 }
 
 static LogicalResult
-extraCheckForScatteredLoadStore(vector::TransferReadOp readOp,
+extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
                                 PatternRewriter &rewriter) {
   // 1. it must be inbound access by checking in_bounds attributes, like
   // {in_bounds = [false, true]}
-  if (readOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(readOp,
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp,
                                        "Out-of-bounds access is not supported "
                                        "for scatter load/store lowering");
   // 2. if the memref has static shape, its lower rank must exactly match with
   // vector shape.
-  if (auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType())) {
+  if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
     if (memrefType.hasStaticShape()) {
       ArrayRef<int64_t> memrefShape = memrefType.getShape();
-      ArrayRef<int64_t> vectorShape = readOp.getVectorType().getShape();
+      ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
       size_t memrefRank = memrefShape.size();
       size_t vectorRank = vectorShape.size();
       if (vectorRank > memrefRank)
         return rewriter.notifyMatchFailure(
-            readOp, "Vector rank cannot exceed memref rank");
+            xferOp, "Vector rank cannot exceed memref rank");
       // Compare the last vectorRank dimensions of memref with vector shape
       for (size_t i = 0; i < vectorRank; ++i) {
         if (memrefShape[memrefRank - vectorRank + i] <= vectorShape[i])
           return rewriter.notifyMatchFailure(
-              readOp, "Memref lower dimensions must match vector shape");
+              xferOp, "Memref lower dimensions must match vector shape");
       }
     }
   }
@@ -574,9 +574,13 @@ struct TransferWriteLowering
 
     auto chip = xegpu::getXeGPUChipStr(writeOp);
     if (chip != "pvc" && chip != "bmg") {
+      // Perform additional checks required by the scattered store lowering.
+      if (failed(extraCheckForScatteredLoadStore(writeOp, rewriter)))
+        return failure();
       // calling another function that lower TransferWriteOp to regular StoreOp
       return lowerTransferWriteToStoreOp(writeOp, rewriter);
     }
+
     // Perform common data transfer checks.
     VectorType vecTy = writeOp.getVectorType();
     if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index ab03416ca8962..33228b2b3c4e2 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD_ND
-// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD_GATHER
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm_* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-ND
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=LOAD-GATHER
 
 gpu.module @xevm_module {
 gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector<8xf32> {
@@ -9,26 +9,26 @@ gpu.func @load_1D_vector(%source: memref<8x16x32xf32>, %offset: index) -> vector
   gpu.return %0 : vector<8xf32>
 }
 
-// LOAD_ND-LABEL:  @load_1D_vector(
-// LOAD_ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
-// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD_ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD_ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// LOAD_ND-SAME:     boundary_check = false
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
-// LOAD_ND:        return %[[VEC]]
-
-// LOAD_GATHER-LABEL:  @load_1D_vector(
-// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
-// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
-// LOAD_GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
+// LOAD-ND-LABEL:  @load_1D_vector(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8xf32>
+// LOAD-ND:        return %[[VEC]]
+
+// LOAD-GATHER-LABEL:  @load_1D_vector(
+// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// LOAD-GATHER:        %[[STEP:.+]] = vector.step : vector<8xindex>
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}:  index to vector<8xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], %[[STEP]] : vector<8xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8xindex>, vector<8xi1> -> vector<8xf32>
 
 }
 
@@ -42,28 +42,28 @@ gpu.func @load_2D_vector(%source: memref<8x16x32xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @load_2D_vector(
-// LOAD_ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
-// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
-// LOAD_ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// LOAD_ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// LOAD_ND-SAME:     boundary_check = false
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// LOAD_ND:        return %[[VEC]]
-
-// LOAD_GATHER-LABEL:  @load_2D_vector(
-// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
-// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD_GATHER-COUNT2: vector.step
-// LOAD_GATHER-COUNT2: vector.shape_cast
-// LOAD_GATHER-COUNT2: vector.broadcast
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
-// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-ND-LABEL:  @load_2D_vector(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc
+// LOAD-ND-SAME:     %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// LOAD-ND-SAME:     boundary_check = false
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        return %[[VEC]]
+
+// LOAD-GATHER-LABEL:  @load_2D_vector(
+// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}}: index to vector<8x16xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -78,16 +78,16 @@ gpu.func @load_zero_pad_out_of_bounds(%source: memref<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @load_zero_pad_out_of_bounds(
-// LOAD_ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
-// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// LOAD_ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// LOAD_ND:        return %[[VEC]]
+// LOAD-ND-LABEL:  @load_zero_pad_out_of_bounds(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        return %[[VEC]]
 
-// LOAD_GATHER-LABEL:  @load_zero_pad_out_of_bounds(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @load_zero_pad_out_of_bounds(
+// LOAD-GATHER:        vector.transfer_read
 
 }
 
@@ -103,29 +103,29 @@ gpu.func @load_transposed(%source: memref<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @load_transposed(
-// LOAD_ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
-// LOAD_ND-SAME:   %[[OFFSET1:.+]]: index, 
-// LOAD_ND-SAME:   %[[OFFSET2:.+]]: index  
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
-// LOAD_ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
-// LOAD_ND-SAME:     -> vector<8x16xf32>
-// LOAD_ND:        return %[[VEC]]
-
-
-// LOAD_GATHER-LABEL:  @load_transposed(
-// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
-// LOAD_GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD_GATHER-COUNT2:  vector.step
-// LOAD_GATHER-COUNT2:  vector.shape_cast
-// LOAD_GATHER-COUNT2: vector.broadcast
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
-// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-ND-LABEL:  @load_transposed(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-ND-SAME:   %[[OFFSET1:.+]]: index, 
+// LOAD-ND-SAME:   %[[OFFSET2:.+]]: index  
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET1]], %[[OFFSET2]]]
+// LOAD-ND-SAME:     memref<32x64xf32> -> !xegpu.tensor_desc<16x8xf32
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] <{transpose = array<i64: 1, 0>}>
+// LOAD-ND-SAME:     -> vector<8x16xf32>
+// LOAD-ND:        return %[[VEC]]
+
+
+// LOAD-GATHER-LABEL:  @load_transposed(
+// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf32>,
+// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2:  vector.step
+// LOAD-GATHER-COUNT2:  vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf32> into memref<2048xf32>
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
 
 }
 
@@ -138,35 +138,35 @@ gpu.func @load_dynamic_source(%source: memref<?x?x?xf32>,
     {in_bounds = [true, true]} : memref<?x?x?xf32>, vector<8x16xf32>
   gpu.return %0 : vector<8x16xf32>
 }
-// LOAD_ND-LABEL:  @load_dynamic_source(
-// LOAD_ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
-// LOAD_ND-SAME:   %[[OFFSET:.+]]: index
-// LOAD_ND:        %[[C2:.+]] = arith.constant 2 : index
-// LOAD_ND:        %[[C1:.+]] = arith.constant 1 : index
-// LOAD_ND:        %[[C0:.+]] = arith.constant 0 : index
-// LOAD_ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// LOAD_ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// LOAD_ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// LOAD_ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
-// LOAD_ND:        return %[[VEC]]
-
-
-// LOAD_GATHER-LABEL:  @load_dynamic_source(
-// LOAD_GATHER-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>,
-// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD_GATHER:        memref.extract_strided_metadata %[[ARG0]]
-// LOAD_GATHER-COUNT2: vector.step
-// LOAD_GATHER-COUNT2: vector.shape_cast
-// LOAD_GATHER-COUNT2: vector.broadcast
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// LOAD_GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
-// LOAD_GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
-// LOAD_GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
+// LOAD-ND-LABEL:  @load_dynamic_source(
+// LOAD-ND-SAME:   %[[SRC:.+]]: memref<?x?x?xf32>,
+// LOAD-ND-SAME:   %[[OFFSET:.+]]: index
+// LOAD-ND:        %[[C2:.+]] = arith.constant 2 : index
+// LOAD-ND:        %[[C1:.+]] = arith.constant 1 : index
+// LOAD-ND:        %[[C0:.+]] = arith.constant 0 : index
+// LOAD-ND-DAG:    %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// LOAD-ND-DAG:    %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// LOAD-ND-DAG:    %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// LOAD-ND:        %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET:.+]], %[[OFFSET:.+]], %[[OFFSET:.+]]]
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
+// LOAD-ND:        return %[[VEC]]
+
+
+// LOAD-GATHER-LABEL:  @load_dynamic_source(
+// LOAD-GATHER-SAME:   %[[ARG0:.+]]: memref<?x?x?xf32>,
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER:        memref.extract_strided_metadata %[[ARG0]]
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[BROADIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER:        %[[FINALIDX:.+]] = arith.addi %[[BROADIDX]], {{.*}} : vector<8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// LOAD-GATHER:        %[[RES:.+]] = xegpu.load %[[COLLAPSE]][%[[FINALIDX]]], %[[CST]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32>
+// LOAD-GATHER:        gpu.return %[[RES]] : vector<8x16xf32>
 }
 
 // -----
@@ -179,24 +179,24 @@ gpu.func @load_dynamic_source2(%source: memref<?x8x16xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @load_dynamic_source2(
-// LOAD_ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
-// LOAD_ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
-// LOAD_ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
-// LOAD_ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
-// LOAD_ND:        return %[[VEC]] : vector<8x16xf32>
-
-// LOAD_GATHER-LABEL:  @load_dynamic_source2(
-// LOAD_GATHER-DAG:    %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD_GATHER-COUNT2: vector.step
-// LOAD_GATHER-COUNT2: vector.shape_cast
-// LOAD_GATHER-COUNT2: vector.broadcast
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// LOAD_GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
-// LOAD_GATHER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
-// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
+// LOAD-ND-LABEL:  @load_dynamic_source2(
+// LOAD-ND-DAG:    %[[C0:.+]] = arith.constant 0 : index
+// LOAD-ND-DAG:    %[[DIM:.+]] = memref.dim %{{.*}}, %[[C0]] : memref<?x8x16xf32>
+// LOAD-ND:        %[[DESC:.+]] = xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}], shape : [%[[DIM]], 8, 16], strides : [128, 16, 1] : memref<?x8x16xf32> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>>
+// LOAD-ND:        %[[VEC:.+]] = xegpu.load_nd %[[DESC]] : !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<boundary_check = false>> -> vector<8x16xf32>
+// LOAD-ND:        return %[[VEC]] : vector<8x16xf32>
+
+// LOAD-GATHER-LABEL:  @load_dynamic_source2(
+// LOAD-GATHER-DAG:    %[[CST_0:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT2: vector.step
+// LOAD-GATHER-COUNT2: vector.shape_cast
+// LOAD-GATHER-COUNT2: vector.broadcast
+// LOAD-GATHER-COUNT2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT2: arith.addi {{.*}} : index
+// LOAD-GATHER-DAG:    %[[BCASTIDX:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER-DAG:    %[[OFFSETS:.+]] = arith.addi %[[BCASTIDX]], {{.*}} : vector<8x16xindex>
+// LOAD-GATHER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<?x8x16xf32> into memref<?xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[OFFSETS]]{{\]}}, %[[CST_0]] : memref<?xf32>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf32> 
 
 }
 
@@ -210,23 +210,23 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
   gpu.return %0 : vector<2x4x8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @load_dynamic_source3(
-// LOAD_ND:        vector.transfer_read
-
-// LOAD_GATHER-LABEL:  @load_dynamic_source3(
-// LOAD_GATHER-SAME:   %[[SRC:.+]]: memref<?x?x?x?x?xf32>
-// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
-// LOAD_GATHER:        memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
-// LOAD_GATHER-COUNT4: vector.step
-// LOAD_GATHER-COUNT3: vector.broadcast
-// LOAD_GATHER-COUNT4: vector.shape_cast
-// LOAD_GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
-// LOAD_GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
-// LOAD_GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
-// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
-// LOAD_GATHER:        return %[[VEC]]
+// LOAD-ND-LABEL:  @load_dynamic_source3(
+// LOAD-ND:        vector.transfer_read
+
+// LOAD-GATHER-LABEL:  @load_dynamic_source3(
+// LOAD-GATHER-SAME:   %[[SRC:.+]]: memref<?x?x?x?x?xf32>
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<2x4x8x16xi1>
+// LOAD-GATHER:        memref.extract_strided_metadata %[[SRC]] : memref<?x?x?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index, index, index, index, index
+// LOAD-GATHER-COUNT4: vector.step
+// LOAD-GATHER-COUNT3: vector.broadcast
+// LOAD-GATHER-COUNT4: vector.shape_cast
+// LOAD-GATHER-COUNT4: vector.broadcast {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER-COUNT3: arith.addi {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER:        %[[SPLAT:.+]] = vector.broadcast {{.*}} : index to vector<2x4x8x16xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[SPLAT]], {{.*}} : vector<2x4x8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2, 3, 4]{{\]}} : memref<?x?x?x?x?xf32> into memref<?xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : memref<?xf32>, vector<2x4x8x16xindex>, vector<2x4x8x16xi1> -> vector<2x4x8x16xf32>
+// LOAD-GATHER:        return %[[VEC]]
 }
 
 // -----
@@ -241,11 +241,11 @@ gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
   gpu.return %0, %1 : vector<8x16xf32>, vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:    @no_load_out_of_bounds_non_zero_pad(
-// LOAD_ND-COUNT-2: vector.transfer_read
+// LOAD-ND-LABEL:    @no_load_out_of_bounds_non_zero_pad(
+// LOAD-ND-COUNT-2: vector.transfer_read
 
-// LOAD_GATHER-LABEL: @no_load_out_of_bounds_non_zero_pad(
-// LOAD_GATHER-COUNT-2: vector.transfer_read
+// LOAD-GATHER-LABEL: @no_load_out_of_bounds_non_zero_pad(
+// LOAD-GATHER-COUNT-2: vector.transfer_read
 }
 
 // -----
@@ -258,11 +258,11 @@ gpu.func @no_load_out_of_bounds_1D_vector(%source: memref<8x16x32xf32>,
   gpu.return %0 : vector<8xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_out_of_bounds_1D_vector(
-// LOAD_ND:        vector.transfer_read
+// LOAD-ND-LABEL:  @no_load_out_of_bounds_1D_vector(
+// LOAD-ND:        vector.transfer_read
 
-// LOAD_GATHER-LABEL:  @no_load_out_of_bounds_1D_vector(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @no_load_out_of_bounds_1D_vector(
+// LOAD-GATHER:        vector.transfer_read
 }
 
 // -----
@@ -276,11 +276,11 @@ gpu.func @no_load_masked(%source : memref<4xf32>,
   gpu.return %0 : vector<4xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_masked(
-// LOAD_ND:        vector.transfer_read
+// LOAD-ND-LABEL:  @no_load_masked(
+// LOAD-ND:        vector.transfer_read
 
-// LOAD_GATHER-LABEL:  @no_load_masked(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @no_load_masked(
+// LOAD-GATHER:        vector.transfer_read
 }
 
 // -----
@@ -293,11 +293,11 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_tensor(
-// LOAD_ND:        vector.transfer_read
+// LOAD-ND-LABEL:  @no_load_tensor(
+// LOAD-ND:        vector.transfer_read
 
-// LOAD_GATHER-LABEL:  @no_load_tensor(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @no_load_tensor(
+// LOAD-GATHER:        vector.transfer_read
 }
 
 // -----
@@ -310,23 +310,23 @@ gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16x32xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_high_dim_vector(
-// LOAD_ND:        vector.transfer_read
-
-// LOAD_GATHER-LABEL:  @no_load_high_dim_vector(
-// LOAD_GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
-// LOAD_GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
-// LOAD_GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
-// LOAD_GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
-// LOAD_GATHER:        %[[C64:.+]] = arith.constant 64 : index
-// LOAD_GATHER-COUNT3: vector.step
-// LOAD_GATHER-COUNT3: vector.shape_cast
-// LOAD_GATHER-COUNT3: vector.broadcast {{.*}} : vector<8x16x32xindex>
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : vector<8x16x32xindex>
-// LOAD_GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD_GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+// LOAD-ND-LABEL:  @no_load_high_dim_vector(
+// LOAD-ND:        vector.transfer_read
+
+// LOAD-GATHER-LABEL:  @no_load_high_dim_vector(
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// LOAD-GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// LOAD-GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// LOAD-GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
+// LOAD-GATHER:        %[[C64:.+]] = arith.constant 64 : index
+// LOAD-GATHER-COUNT-3: vector.step
+// LOAD-GATHER-COUNT-3: vector.shape_cast
+// LOAD-GATHER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
 
 }
 
@@ -341,11 +341,11 @@ gpu.func @no_load_non_unit_inner_stride(
   gpu.return %0 : vector<8xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_non_unit_inner_stride(
-// LOAD_ND:        vector.transfer_read
+// LOAD-ND-LABEL:  @no_load_non_unit_inner_stride(
+// LOAD-ND:        vector.transfer_read
 
-// LOAD_GATHER-LABEL:  @no_load_non_unit_inner_stride(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @no_load_non_unit_inner_stride(
+// LOAD-GATHER:        vector.transfer_read
 }
 
 
@@ -360,11 +360,11 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
   gpu.return %0 : vector<8x16xf32>
 }
 
-// LOAD_ND-LABEL:  @no_load_unsupported_map(
-// LOAD_ND:        vector.transfer_read
+// LOAD-ND-LABEL:  @no_load_unsupported_map(
+// LOAD-ND:        vector.transfer_read
 
-// LOAD_GATHER-LABEL:  @no_load_unsupported_map(
-// LOAD_GATHER:        vector.transfer_read
+// LOAD-GATHER-LABEL:  @no_load_unsupported_map(
+// LOAD-GATHER:        vector.transfer_read
 }
 
 // -----
@@ -378,19 +378,19 @@ gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
   gpu.return %0 : vector<8x16xf16>
 }
 
-// LOAD_ND-LABEL:  @no_load_transpose_unsupported_data_type(
-// LOAD_ND:        vector.transfer_read
-
-// LOAD_GATHER-LABEL:  @no_load_transpose_unsupported_data_type(
-// LOAD_GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
-// LOAD_GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD_GATHER-COUNT2:  vector.step
-// LOAD_GATHER-COUNT2:  vector.shape_cast
-// LOAD_GATHER-COUNT2: vector.broadcast
-// LOAD_GATHER-COUNT2: arith.muli {{.*}} : index
-// LOAD_GATHER-COUNT2: arith.addi {{.*}} : index
-// LOAD_GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// LOAD_GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD_GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD_GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+// LOAD-ND-LABEL:  @no_load_transpose_unsupported_data_type(
+// LOAD-ND:        vector.transfer_read
+
+// LOAD-GATHER-LABEL:  @no_load_transpose_unsupported_data_type(
+// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
+// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT-2:  vector.step
+// LOAD-GATHER-COUNT-2:  vector.shape_cast
+// LOAD-GATHER-COUNT-2: vector.broadcast
+// LOAD-GATHER-COUNT-2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
 }
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index e244995dd5817..afc37959219d7 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -1,4 +1,6 @@
-// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s
+// RUN: mlir-opt %s --xevm-attach-target='module=xevm.* O=3 chip=pvc' -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-ND
+// RUN: mlir-opt %s -convert-vector-to-xegpu -split-input-file | FileCheck %s --check-prefix=STORE-SCATTER
+
 
 gpu.module @xevm_module {
 gpu.func @store_1D_vector(%vec: vector<8xf32>,
@@ -9,15 +11,27 @@ gpu.func @store_1D_vector(%vec: vector<8xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @store_1D_vector(
-// CHECK-SAME:  %[[VEC:.+]]: vector<8xf32>,
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
-// CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+// STORE-ND-LABEL: @store_1D_vector(
+// STORE-ND-SAME:  %[[VEC:.+]]: vector<8xf32>,
+// STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8xf32,
+// STORE-ND-SAME:    boundary_check = false
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8xf32>
+
+// STORE-SCATTER-LABEL:  @store_1D_vector(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8xf32>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-SCATTER-DAG:        %[[CST:.+]] = arith.constant dense<true> : vector<8xi1>
+// STORE-SCATTER-DAG:        %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT-2: arith.muli {{.*}} : index
+// STORE-SCATTER-COUNT-2: arith.addi {{.*}} : index
+// STORE-SCATTER-DAG:    %[[BCAST:.+]] = vector.broadcast {{.*}} : index to vector<8xindex>
+// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST]], %{{.*}} : vector<8xindex>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8xf32>, memref<4096xf32>, vector<8xindex>, vector<8xi1>
 }
 
 // -----
@@ -30,15 +44,28 @@ gpu.func @store_2D_vector(%vec: vector<8x16xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @store_2D_vector(
-// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
-// CHECK-SAME:    boundary_check = false
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-LABEL: @store_2D_vector(
+// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME:  %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    memref<8x16x32xf32> -> !xegpu.tensor_desc<8x16xf32,
+// STORE-ND-SAME:    boundary_check = false
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL:  @store_2D_vector(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<8x16x32xf32>,
+// STORE-SCATTER-SAME:   %[[OFFSET:.+]]: index
+// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-COUNT-2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT-2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT-2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<8x16x32xf32> into memref<4096xf32>
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<4096xf32>, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -51,21 +78,34 @@ gpu.func @store_dynamic_source(%vec: vector<8x16xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @store_dynamic_source(
-// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
-// CHECK-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
-// CHECK-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
-// CHECK-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
-// CHECK:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
-// CHECK-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-LABEL: @store_dynamic_source(
+// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND-DAG:   %[[C0:.+]] = arith.constant 0 : index
+// STORE-ND-DAG:   %[[C1:.+]] = arith.constant 1 : index
+// STORE-ND-DAG:   %[[C2:.+]] = arith.constant 2 : index
+// STORE-ND-DAG:   %[[DIM_0:.+]] = memref.dim %[[SRC]], %[[C0]]
+// STORE-ND-DAG:   %[[DIM_1:.+]] = memref.dim %[[SRC]], %[[C1]]
+// STORE-ND-DAG:   %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
+// STORE-ND:       %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// STORE-ND-SAME:    memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL: @store_dynamic_source(
+// STORE-SCATTER-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME:  %[[SRC:.+]]: memref<?x?x?xf32>,
+// STORE-SCATTER-DAG:   %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-DAG:   memref.extract_strided_metadata %[[SRC]] : memref<?x?x?xf32> -> memref<f32>, index, index, index, index, index, index, index
+// STORE-SCATTER-COUNT-2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT-2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT-2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:   %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG:   %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:   %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<?x?x?xf32> into memref<?xf32>
+// STORE-SCATTER:       xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<?xf32>, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -78,14 +118,17 @@ gpu.func @store_out_of_bounds(%vec: vector<8x16xf32>,
   gpu.return
 }
 
-// CHECK-LABEL:   @store_out_of_bounds(
-// CHECK-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
-// CHECK-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
-// CHECK-SAME:  %[[OFFSET:.+]]: index
-// CHECK:       %[[DESC:.+]] = xegpu.create_nd_tdesc
-// CHECK-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
-// CHECK:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+// STORE-ND-LABEL:   @store_out_of_bounds(
+// STORE-ND-SAME:  %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-ND-SAME:  %[[SRC:.+]]: memref<7x64xf32>,
+// STORE-ND-SAME:  %[[OFFSET:.+]]: index
+// STORE-ND:       %[[DESC:.+]] = xegpu.create_nd_tdesc
+// STORE-ND-SAME:    %[[SRC]][%[[OFFSET]], %[[OFFSET]]]
+// STORE-ND-SAME:    memref<7x64xf32> -> !xegpu.tensor_desc<8x16xf32>
+// STORE-ND:       xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
+
+// STORE-SCATTER-LABEL:  @store_out_of_bounds(
+// STORE-SCATTER:   vector.transfer_write
 }
 
 // -----
@@ -99,8 +142,21 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_transposed(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_transposed(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_transposed(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16xf32>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<32x64xf32>,
+// STORE-SCATTER-SAME:   %[[OFFSET:.+]]: index
+// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// STORE-SCATTER-COUNT-2: %[[STEP:.+]] = vector.step
+// STORE-SCATTER-COUNT-2: vector.shape_cast {{.*}}
+// STORE-SCATTER-COUNT-2: vector.broadcast {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}} : vector<8x16xindex>
+// STORE-SCATTER-DAG:    %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1]{{\]}} : memref<32x64xf32> into memref<2048xf32>
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1>
 }
 
 // -----
@@ -114,8 +170,11 @@ gpu.func @no_store_masked(%vec: vector<4xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_masked(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_masked(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_masked(
+// STORE-SCATTER:        vector.transfer_write
 }
 
 // -----
@@ -128,8 +187,11 @@ gpu.func @no_store_tensor(%vec: vector<8x16xf32>,
   gpu.return %0 : tensor<32x64xf32>
 }
 
-// CHECK-LABEL: @no_store_tensor(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_tensor(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_tensor(
+// STORE-SCATTER:        vector.transfer_write
 }
 
 // -----
@@ -142,8 +204,25 @@ gpu.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_high_dim_vector(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_high_dim_vector(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_high_dim_vector(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16x32xf32>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<16x32x64xf32>
+// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// STORE-SCATTER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// STORE-SCATTER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// STORE-SCATTER:        %[[C2048:.+]] = arith.constant 2048 : index
+// STORE-SCATTER:        %[[C64:.+]] = arith.constant 64 : index
+// STORE-SCATTER-COUNT-3: vector.step
+// STORE-SCATTER-COUNT-3: vector.shape_cast
+// STORE-SCATTER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> 
 }
 
 // -----
@@ -156,8 +235,11 @@ gpu.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_non_unit_inner_stride(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_non_unit_inner_stride(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_non_unit_inner_stride(
+// STORE-SCATTER:        vector.transfer_write
 }
 
 // -----
@@ -171,8 +253,11 @@ gpu.func @no_store_unsupported_map(%vec: vector<8x16xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_unsupported_map(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_unsupported_map(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_unsupported_map(
+// STORE-SCATTER:        vector.transfer_write
 }
 
 // -----
@@ -185,6 +270,9 @@ gpu.func @no_store_out_of_bounds_1D_vector(%vec: vector<8xf32>,
   gpu.return
 }
 
-// CHECK-LABEL: @no_store_out_of_bounds_1D_vector(
-// CHECK:       vector.transfer_write
+// STORE-ND-LABEL: @no_store_out_of_bounds_1D_vector(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @no_store_out_of_bounds_1D_vector(
+// STORE-SCATTER:        vector.transfer_write
 }
\ No newline at end of file
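The STORE-SCATTER checks above all follow the same recipe: build per-dimension index vectors with vector.step, scale them by the source memref's strides, sum them into one offset vector, collapse the memref to 1-D, and store through xegpu.store. As a reference, here is a minimal scalar model of that index arithmetic; it is an illustration only, not the MLIR pattern itself, and the helper names and the main driver are assumptions made for the sketch.

#include <cstdint>
#include <iostream>
#include <vector>

// Row-major strides for a static shape, e.g. 8x16x32 -> {512, 32, 1}.
static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  return strides;
}

// Linear offset of one element: its indices scaled by the memref strides and
// summed (the lowering builds the vectorized form of this with
// vector.step/shape_cast/broadcast plus arith.muli/addi).
static int64_t linearOffset(const std::vector<int64_t> &strides,
                            const std::vector<int64_t> &indices) {
  int64_t offset = 0;
  for (size_t d = 0; d < indices.size(); ++d)
    offset += indices[d] * strides[d];
  return offset;
}

int main() {
  auto strides = rowMajorStrides({8, 16, 32}); // memref<8x16x32xf32>
  // Element (1, 2, 3) maps to 1*512 + 2*32 + 3*1 = 579 in the collapsed
  // memref<4096xf32> that xegpu.store indexes.
  std::cout << linearOffset(strides, {1, 2, 3}) << "\n";
  return 0;
}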

>From 35e332234e9629575b32bdc6af7abf7291b85036 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 7 Aug 2025 02:25:03 +0000
Subject: [PATCH 09/17] change comments

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 125 +++++++++---------
 1 file changed, 59 insertions(+), 66 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c59a2060de6a3..d24aaaacdffba 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -152,6 +152,37 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
+static LogicalResult
+extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
+                                PatternRewriter &rewriter) {
+  // 1. It must be an in-bounds access, checked via the in_bounds attribute,
+  // e.g. {in_bounds = [false, true]} marks dim 0 as potentially out of bounds.
+  if (xferOp.hasOutOfBoundsDim())
+    return rewriter.notifyMatchFailure(xferOp,
+                                       "Out-of-bounds access is not supported "
+                                       "for scatter load/store lowering");
+  // 2. If the memref has a static shape, its trailing dimensions must match
+  // the vector shape.
+  if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
+    if (memrefType.hasStaticShape()) {
+      ArrayRef<int64_t> memrefShape = memrefType.getShape();
+      ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
+      size_t memrefRank = memrefShape.size();
+      size_t vectorRank = vectorShape.size();
+      if (vectorRank > memrefRank)
+        return rewriter.notifyMatchFailure(
+            xferOp, "Vector rank cannot exceed memref rank");
+      // Compare the last vectorRank dimensions of memref with vector shape
+      for (size_t i = 0; i < vectorRank; ++i) {
+        if (memrefShape[memrefRank - vectorRank + i] <= vectorShape[i])
+          return rewriter.notifyMatchFailure(
+              xferOp, "Memref lower dimensions must match vector shape");
+      }
+    }
+  }
+  return success();
+}
+
 static LogicalResult adjustStridesForPermutation(
     Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
     AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
@@ -237,41 +268,34 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
   return strides;
 }
 
-// This function lowers vector.transfer_read to XeGPU load operation.
-  // Example:
-  //   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0], 
-  //               %cst {in_bounds = [true, true, true, true]}>} : 
-  //               memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
-  // 
-  //   %6 = vector.step: vector<4xindex> 
-  //   %7 = vector.step: vector<2xindex> 
-  //   %8 = vector.step: vector<6xindex> 
-  //   %9 = vector.step: vector<32xindex> 
-  //   %10 = arith.mul %6, 384
-  //   %11 = arith.mul %7, 192
-  //   %12 = arith.mul %8, 32
-  //   %13 = arith.mul %9, 1
-  //   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
-  //   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
-  //   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
-  //   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
-  //   %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>  
-  //   %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>  
-  //   %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>  
-  //   %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>  
-  //   %22 = arith.add %18, %19
-  //   %23 = arith.add %20, %21
-  //   %local_offsets = arith.add %22, %23
-  //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
-  //   %offsets =  orig_offset + local_offsets
-
-//   %expand_shape1 = memref.collapseshape %expand_shape:
-//   memref<8x4x2x6x32xbf16> -> memref<?bf16>
-
-//   %vec = xegpu.load_gather %expand_shape1[%offsets]:memref<?xbf16>,
-//                           vector<4x2x6x32xindex> -> vector<4x2x6x32xbf16>
-
-// Compute localOffsets for load_gather and store_scatter
+// This function computes the vector of localOffsets for scattered load/stores.
+// It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter. Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//               %cst {in_bounds = [true, true, true, true]}>} :
+//               memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+//   %6 = vector.step: vector<4xindex>
+//   %7 = vector.step: vector<2xindex>
+//   %8 = vector.step: vector<6xindex>
+//   %9 = vector.step: vector<32xindex>
+//   %10 = arith.mul %6, 384
+//   %11 = arith.mul %7, 192
+//   %12 = arith.mul %8, 32
+//   %13 = arith.mul %9, 1
+//   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
+//   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
+//   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
+//   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
+//   %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
+//   %22 = arith.add %18, %19
+//   %23 = arith.add %20, %21
+//   %local_offsets = arith.add %22, %23
+//   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+//   %offsets =  orig_offset + local_offsets
 static Value computeGatherOffsets(VectorTransferOpInterface xferOp,
                                   PatternRewriter &rewriter,
                                   ArrayRef<Value> strides) {
@@ -462,37 +486,6 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
                             localOffsets);
 }
 
-static LogicalResult
-extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
-  // 1. it must be inbound access by checking in_bounds attributes, like
-  // {in_bounds = [false, true]}
-  if (xferOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(xferOp,
-                                       "Out-of-bounds access is not supported "
-                                       "for scatter load/store lowering");
-  // 2. if the memref has static shape, its lower rank must exactly match with
-  // vector shape.
-  if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
-    if (memrefType.hasStaticShape()) {
-      ArrayRef<int64_t> memrefShape = memrefType.getShape();
-      ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
-      size_t memrefRank = memrefShape.size();
-      size_t vectorRank = vectorShape.size();
-      if (vectorRank > memrefRank)
-        return rewriter.notifyMatchFailure(
-            xferOp, "Vector rank cannot exceed memref rank");
-      // Compare the last vectorRank dimensions of memref with vector shape
-      for (size_t i = 0; i < vectorRank; ++i) {
-        if (memrefShape[memrefRank - vectorRank + i] <= vectorShape[i])
-          return rewriter.notifyMatchFailure(
-              xferOp, "Memref lower dimensions must match vector shape");
-      }
-    }
-  }
-  return success();
-}
-
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
   using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
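The extraCheckForScatteredLoadStore helper moved above enforces two preconditions before the scattered lowering is attempted: the access must be fully in-bounds, and for statically shaped memrefs the trailing dimensions must be able to hold the vector. Below is a simplified standalone model of the shape rule, using plain integers instead of MLIR types and the `>` comparison that a later commit in this series settles on; it is a sketch, not the in-tree helper.

#include <cstdint>
#include <vector>

// Model of the static-shape precondition: the vector rank must not exceed the
// memref rank, and every trailing memref dimension must be large enough to
// hold the matching vector dimension.
static bool trailingDimsFitVector(const std::vector<int64_t> &memrefShape,
                                  const std::vector<int64_t> &vectorShape) {
  size_t memrefRank = memrefShape.size();
  size_t vectorRank = vectorShape.size();
  if (vectorRank > memrefRank)
    return false; // Vector rank cannot exceed memref rank.
  for (size_t i = 0; i < vectorRank; ++i)
    if (vectorShape[i] > memrefShape[memrefRank - vectorRank + i])
      return false; // Trailing memref dim too small for the vector dim.
  return true;
}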
 

>From c86b80836686952feb272261db89a41074e2d0f0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 7 Aug 2025 02:43:08 +0000
Subject: [PATCH 10/17] change function name

---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index d24aaaacdffba..9f8fae841597a 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -296,9 +296,9 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
 //   %local_offsets = arith.add %22, %23
 //   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
 //   %offsets =  orig_offset + local_offsets
-static Value computeGatherOffsets(VectorTransferOpInterface xferOp,
-                                  PatternRewriter &rewriter,
-                                  ArrayRef<Value> strides) {
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+                            PatternRewriter &rewriter,
+                            ArrayRef<Value> strides) {
   Location loc = xferOp.getLoc();
   VectorType vectorType = xferOp.getVectorType();
   SmallVector<Value> indices(xferOp.getIndices().begin(),
@@ -462,7 +462,7 @@ LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
   if (strides.empty())
     return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
 
-  Value localOffsets = computeGatherOffsets(readOp, rewriter, strides);
+  Value localOffsets = computeOffsets(readOp, rewriter, strides);
 
   Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
 
@@ -478,7 +478,7 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
 
   SmallVector<Value> strides = computeStrides(writeOp, rewriter);
 
-  Value localOffsets = computeGatherOffsets(writeOp, rewriter, strides);
+  Value localOffsets = computeOffsets(writeOp, rewriter, strides);
 
   Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
 

>From a51061e09fd086f018bf44407db0aff7bb5855d7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 11 Aug 2025 23:55:17 +0000
Subject: [PATCH 11/17] modernize op creation

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 67 ++++++++++---------
 1 file changed, 35 insertions(+), 32 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 9f8fae841597a..2166e2110878d 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -236,7 +236,7 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     }
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
-      strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
+      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
   } else {
     // For dynamic shape memref, use memref.extract_strided_metadata to get
     // stride values
@@ -256,8 +256,8 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
       resultTypes.push_back(indexType); // sizes
     }
 
-    auto meta = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, resultTypes, baseMemref);
+    auto meta = memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, resultTypes, baseMemref);
     strides.append(meta.getStrides().begin(), meta.getStrides().end());
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
@@ -309,7 +309,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   SmallVector<Value> stepVectors;
   for (int64_t dim : vectorShape) {
     auto stepType = VectorType::get({dim}, rewriter.getIndexType());
-    auto stepOp = rewriter.create<vector::StepOp>(loc, stepType);
+    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
     stepVectors.push_back(stepOp);
   }
 
@@ -321,9 +321,9 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     size_t memrefDim = memrefRank - vectorRank + i;
     Value strideValue = strides[memrefDim];
     auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
-    auto mulOp = rewriter.create<arith::MulIOp>(
-        loc, stepVectors[i],
-        rewriter.create<vector::BroadcastOp>(loc, mulType, strideValue));
+    auto bcastOp =
+        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
     strideMultiplied.push_back(mulOp);
   }
 
@@ -333,8 +333,8 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     SmallVector<int64_t> newShape(vectorRank, 1);
     newShape[i] = vectorShape[i];
     auto newType = VectorType::get(newShape, rewriter.getIndexType());
-    auto castOp =
-        rewriter.create<vector::ShapeCastOp>(loc, newType, strideMultiplied[i]);
+    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+                                              strideMultiplied[i]);
     shapeCasted.push_back(castOp);
   }
 
@@ -343,8 +343,8 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   auto fullIndexVectorType =
       VectorType::get(vectorShape, rewriter.getIndexType());
   for (Value shapeCastVal : shapeCasted) {
-    auto broadcastOp = rewriter.create<vector::BroadcastOp>(
-        loc, fullIndexVectorType, shapeCastVal);
+    auto broadcastOp = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, shapeCastVal);
     broadcasted.push_back(broadcastOp);
   }
 
@@ -352,24 +352,25 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   Value localOffsets = broadcasted[0];
   for (size_t i = 1; i < broadcasted.size(); ++i) {
     localOffsets =
-        rewriter.create<arith::AddIOp>(loc, localOffsets, broadcasted[i]);
+        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
   }
 
   // Step 6: Compute base offset from transfer read indices
   Value baseOffset = nullptr;
   if (!indices.empty()) {
-    baseOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
     for (size_t i = 0; i < indices.size(); ++i) {
       Value strideVal = strides[i];
       Value offsetContrib =
-          rewriter.create<arith::MulIOp>(loc, indices[i], strideVal);
+          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
       baseOffset =
-          rewriter.create<arith::AddIOp>(loc, baseOffset, offsetContrib);
+          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
     }
     // Broadcast base offset to match vector shape
-    Value bcastBase = rewriter.create<vector::BroadcastOp>(
-        loc, fullIndexVectorType, baseOffset);
-    localOffsets = rewriter.create<arith::AddIOp>(loc, bcastBase, localOffsets);
+    Value bcastBase = vector::BroadcastOp::create(
+        rewriter, loc, fullIndexVectorType, baseOffset);
+    localOffsets =
+        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
   }
   return localOffsets;
 }
@@ -408,8 +409,8 @@ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
   }
   reassociation.push_back(allDims);
 
-  auto collapseOp = rewriter.create<memref::CollapseShapeOp>(
-      loc, flatMemrefType, baseMemref, reassociation);
+  auto collapseOp = memref::CollapseShapeOp::create(
+      rewriter, loc, flatMemrefType, baseMemref, reassociation);
   return collapseOp;
 }
 
@@ -420,10 +421,11 @@ static LogicalResult createLoadGather(vector::TransferReadOp readOp,
   Location loc = readOp.getLoc();
   VectorType vectorType = readOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  Value mask = rewriter.create<vector::ConstantMaskOp>(
-      loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
-  auto gatherOp = rewriter.create<xegpu::LoadGatherOp>(
-      loc, vectorType, flatMemref, localOffsets, mask,
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  auto gatherOp = xegpu::LoadGatherOp::create(
+      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
       /*chunk_size=*/IntegerAttr{},
       /*l1_hint=*/xegpu::CachePolicyAttr{},
       /*l2_hint=*/xegpu::CachePolicyAttr{},
@@ -439,14 +441,15 @@ static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
   Location loc = writeOp.getLoc();
   VectorType vectorType = writeOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  Value mask = rewriter.create<vector::ConstantMaskOp>(
-      loc, VectorType::get(vectorShape, rewriter.getI1Type()), vectorShape);
-  rewriter.create<xegpu::StoreScatterOp>(loc, value, flatMemref, localOffsets,
-                                         mask,
-                                         /*chunk_size=*/IntegerAttr{},
-                                         /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                         /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                         /*l3_hint=*/xegpu::CachePolicyAttr{});
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  xegpu::StoreScatterOp::create(rewriter, loc, value, flatMemref, localOffsets,
+                                mask,
+                                /*chunk_size=*/IntegerAttr{},
+                                /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                /*l3_hint=*/xegpu::CachePolicyAttr{});
   rewriter.eraseOp(writeOp);
   return success();
 }
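This commit is purely mechanical: every builder call switches from the member-template form to the static create form, with the rewriter passed as the first argument and the remaining operands unchanged. One pair from the diff above, shown side by side for reference (comments only, no new code is introduced here):

// Before (member-template builder):
//   strides.push_back(rewriter.create<arith::ConstantIndexOp>(loc, s));
// After (static create, builder passed explicitly):
//   strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));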

>From cb0628ef588836517b54496d32a4d306418a028a Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 12 Aug 2025 00:57:33 +0000
Subject: [PATCH 12/17] address comments

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           |  18 +--
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  21 ++--
 .../VectorToXeGPU/transfer-read-to-xegpu.mlir | 114 +++++++++---------
 .../transfer-write-to-xegpu.mlir              |  62 +++++-----
 4 files changed, 111 insertions(+), 104 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 2166e2110878d..c64cccab3b261 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -205,9 +205,9 @@ static LogicalResult adjustStridesForPermutation(
     }
     unsigned pos = dimExpr.getPosition();
     // Map permutation to the relevant strides (innermost dims)
-    if (pos < memrefRank - vecRank) {
+    if (pos < memrefRank - vecRank)
       return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
-    }
+
     // The stride for output dimension outIdx is the stride of input dimension
     // pos
     adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
@@ -305,7 +305,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
                              xferOp.getIndices().end());
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
 
-  // Step 1: Create vector.step operations for each dimension
+  // Create vector.step operations for each dimension
   SmallVector<Value> stepVectors;
   for (int64_t dim : vectorShape) {
     auto stepType = VectorType::get({dim}, rewriter.getIndexType());
@@ -313,7 +313,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     stepVectors.push_back(stepOp);
   }
 
-  // Step 2: Multiply step vectors by corresponding strides
+  // Multiply step vectors by corresponding strides
   size_t memrefRank = strides.size();
   size_t vectorRank = vectorShape.size();
   SmallVector<Value> strideMultiplied;
@@ -327,7 +327,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     strideMultiplied.push_back(mulOp);
   }
 
-  // Step 3: Shape cast each multiplied vector to add singleton dimensions
+  // Shape cast each multiplied vector to add singleton dimensions
   SmallVector<Value> shapeCasted;
   for (size_t i = 0; i < vectorRank; ++i) {
     SmallVector<int64_t> newShape(vectorRank, 1);
@@ -338,7 +338,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     shapeCasted.push_back(castOp);
   }
 
-  // Step 4: Broadcast each shape-casted vector to full vector shape
+  // Broadcast each shape-casted vector to full vector shape
   SmallVector<Value> broadcasted;
   auto fullIndexVectorType =
       VectorType::get(vectorShape, rewriter.getIndexType());
@@ -348,14 +348,14 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
     broadcasted.push_back(broadcastOp);
   }
 
-  // Step 5: Add all broadcasted vectors together to compute local offsets
+  // Add all broadcasted vectors together to compute local offsets
   Value localOffsets = broadcasted[0];
   for (size_t i = 1; i < broadcasted.size(); ++i) {
     localOffsets =
         arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
   }
 
-  // Step 6: Compute base offset from transfer read indices
+  // Compute base offset from transfer read indices
   Value baseOffset = nullptr;
   if (!indices.empty()) {
     baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
@@ -500,6 +500,7 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       return failure();
 
     auto chip = xegpu::getXeGPUChipStr(readOp);
+    // TODO: This check needs to be replaced with a proper uArch capability check.
     if ( chip != "pvc" && chip != "bmg") {
       // perform additional checks -
       if (failed(extraCheckForScatteredLoadStore(readOp, rewriter)))
@@ -569,6 +570,7 @@ struct TransferWriteLowering
       return failure();
 
     auto chip = xegpu::getXeGPUChipStr(writeOp);
+    // TODO: This check needs to be replaced with a proper uArch capability check.
     if (chip != "pvc" && chip != "bmg") {
       // perform additional checks -
       if (failed(extraCheckForScatteredLoadStore(writeOp, rewriter)))
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 6f0b02897d271..1f090952894d2 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -408,16 +408,19 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
 }
 
 std::optional<std::string> xegpu::getXeGPUChipStr(Operation *op) {
-  auto gpuModuleOp = op->getParentOfType<mlir::gpu::GPUModuleOp>();
-  if (gpuModuleOp) {
-    auto targetAttrs = gpuModuleOp.getTargets();
-    if (targetAttrs) {
-      for (auto &attr : *targetAttrs) {
-        auto xevmAttr = llvm::dyn_cast<mlir::xevm::XeVMTargetAttr>(attr);
-        if (xevmAttr)
-          return xevmAttr.getChip().str();
-      }
+  auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
+
+  if (!gpuModuleOp)
+    return std::nullopt;
+
+  auto targetAttrs = gpuModuleOp.getTargets();
+  if (targetAttrs) {
+    for (auto &attr : *targetAttrs) {
+      auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
+      if (xevmAttr)
+        return xevmAttr.getChip().str();
     }
   }
+
   return std::nullopt;
 }
\ No newline at end of file
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
index 33228b2b3c4e2..b373bdab80567 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir
@@ -229,6 +229,64 @@ gpu.func @load_dynamic_source3(%source: memref<?x?x?x?x?xf32>,
 // LOAD-GATHER:        return %[[VEC]]
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @load_high_dim_vector(%source: memref<16x32x64xf32>,
+    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
+  %c0 = arith.constant 0.0 : f32
+  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
+    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
+  gpu.return %0 : vector<8x16x32xf32>
+}
+
+// LOAD-ND-LABEL:  @load_high_dim_vector(
+// LOAD-ND:        vector.transfer_read
+
+// LOAD-GATHER-LABEL:  @load_high_dim_vector(
+// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// LOAD-GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// LOAD-GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// LOAD-GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
+// LOAD-GATHER:        %[[C64:.+]] = arith.constant 64 : index
+// LOAD-GATHER-COUNT-3: vector.step
+// LOAD-GATHER-COUNT-3: vector.shape_cast
+// LOAD-GATHER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
+
+}
+
+// -----
+gpu.module @xevm_module {
+gpu.func @load_transpose_f16(%source: memref<32x64xf16>,
+    %offset: index) -> vector<8x16xf16> {
+  %c0 = arith.constant 0.0 : f16
+  %0 = vector.transfer_read %source[%offset, %offset], %c0
+    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
+    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
+  gpu.return %0 : vector<8x16xf16>
+}
+
+// LOAD-ND-LABEL:  @load_transpose_f16(
+// LOAD-ND:        vector.transfer_read
+
+// LOAD-GATHER-LABEL:  @load_transpose_f16(
+// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
+// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
+// LOAD-GATHER-COUNT-2:  vector.step
+// LOAD-GATHER-COUNT-2:  vector.shape_cast
+// LOAD-GATHER-COUNT-2: vector.broadcast
+// LOAD-GATHER-COUNT-2: arith.muli {{.*}} : index
+// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
+// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
+// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
+// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
+// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
+}
+
 // -----
 gpu.module @xevm_module {
 gpu.func @no_load_out_of_bounds_non_zero_pad(%source: memref<32x64xf32>,
@@ -300,35 +358,6 @@ gpu.func @no_load_tensor(%source: tensor<32x64xf32>,
 // LOAD-GATHER:        vector.transfer_read
 }
 
-// -----
-gpu.module @xevm_module {
-gpu.func @no_load_high_dim_vector(%source: memref<16x32x64xf32>,
-    %offset: index, %arg2: index) -> vector<8x16x32xf32> {
-  %c0 = arith.constant 0.0 : f32
-  %0 = vector.transfer_read %source[%offset, %arg2, %offset], %c0
-    {in_bounds = [true, true, true]} : memref<16x32x64xf32>, vector<8x16x32xf32>
-  gpu.return %0 : vector<8x16x32xf32>
-}
-
-// LOAD-ND-LABEL:  @no_load_high_dim_vector(
-// LOAD-ND:        vector.transfer_read
-
-// LOAD-GATHER-LABEL:  @no_load_high_dim_vector(
-// LOAD-GATHER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
-// LOAD-GATHER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
-// LOAD-GATHER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
-// LOAD-GATHER:        %[[C2048:.+]] = arith.constant 2048 : index
-// LOAD-GATHER:        %[[C64:.+]] = arith.constant 64 : index
-// LOAD-GATHER-COUNT-3: vector.step
-// LOAD-GATHER-COUNT-3: vector.shape_cast
-// LOAD-GATHER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
-// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.+]] = memref.collapse_shape %arg0 {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// LOAD-GATHER:        %[[VEC:.+]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> -> vector<8x16x32xf32>
-
-}
 
 // -----
 gpu.module @xevm_module {
@@ -367,30 +396,3 @@ gpu.func @no_load_unsupported_map(%source: memref<16x32x64xf32>,
 // LOAD-GATHER:        vector.transfer_read
 }
 
-// -----
-gpu.module @xevm_module {
-gpu.func @no_load_transpose_unsupported_data_type(%source: memref<32x64xf16>,
-    %offset: index) -> vector<8x16xf16> {
-  %c0 = arith.constant 0.0 : f16
-  %0 = vector.transfer_read %source[%offset, %offset], %c0
-    {permutation_map = affine_map<(d0, d1) -> (d1, d0)>,
-    in_bounds = [true, true]} : memref<32x64xf16>, vector<8x16xf16>
-  gpu.return %0 : vector<8x16xf16>
-}
-
-// LOAD-ND-LABEL:  @no_load_transpose_unsupported_data_type(
-// LOAD-ND:        vector.transfer_read
-
-// LOAD-GATHER-LABEL:  @no_load_transpose_unsupported_data_type(
-// LOAD-GATHER-SAME:    %[[SRC:.+]]: memref<32x64xf16>,
-// LOAD-GATHER:         %[[CST:.+]] = arith.constant dense<true> : vector<8x16xi1>
-// LOAD-GATHER-COUNT-2:  vector.step
-// LOAD-GATHER-COUNT-2:  vector.shape_cast
-// LOAD-GATHER-COUNT-2: vector.broadcast
-// LOAD-GATHER-COUNT-2: arith.muli {{.*}} : index
-// LOAD-GATHER-COUNT-2: arith.addi {{.*}} : index
-// LOAD-GATHER:        %[[BCAST2:.+]] = vector.broadcast {{.*}} : index to vector<8x16xindex>
-// LOAD-GATHER:        %[[IDX:.+]] = arith.addi %[[BCAST2]], {{.*}}: vector<8x16xindex>
-// LOAD-GATHER:        %[[COLLAPSE:.*]] = memref.collapse_shape %arg0 {{\[\[}}0, 1{{\]\]}} : memref<32x64xf16> into memref<2048xf16>
-// LOAD-GATHER:        %[[LOAD:.*]] = xegpu.load %[[COLLAPSE]][%[[IDX]]], %[[CST]] : memref<2048xf16>, vector<8x16xindex>, vector<8x16xi1> -> vector<8x16xf16>
-}
diff --git a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
index afc37959219d7..b3f761a545ee1 100644
--- a/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
+++ b/mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir
@@ -159,6 +159,37 @@ gpu.func @no_store_transposed(%vec: vector<8x16xf32>,
 // STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]]{{\[}}%[[IDX]]{{\]}}, %[[CST]] : vector<8x16xf32>, memref<2048xf32>, vector<8x16xindex>, vector<8x16xi1>
 }
 
+// -----
+gpu.module @xevm_module {
+gpu.func @store_high_dim_vector(%vec: vector<8x16x32xf32>,
+    %source: memref<16x32x64xf32>, %offset: index) {
+  vector.transfer_write %vec, %source[%offset, %offset, %offset]
+    {in_bounds = [true, true, true]}
+    : vector<8x16x32xf32>, memref<16x32x64xf32>
+  gpu.return
+}
+
+// STORE-ND-LABEL: @store_high_dim_vector(
+// STORE-ND:       vector.transfer_write
+
+// STORE-SCATTER-LABEL:  @store_high_dim_vector(
+// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16x32xf32>,
+// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<16x32x64xf32>
+// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
+// STORE-SCATTER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
+// STORE-SCATTER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
+// STORE-SCATTER:        %[[C2048:.+]] = arith.constant 2048 : index
+// STORE-SCATTER:        %[[C64:.+]] = arith.constant 64 : index
+// STORE-SCATTER-COUNT-3: vector.step
+// STORE-SCATTER-COUNT-3: vector.shape_cast
+// STORE-SCATTER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
+// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
+// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
+// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> 
+}
+
 // -----
 gpu.module @xevm_module {
 gpu.func @no_store_masked(%vec: vector<4xf32>,
@@ -194,37 +225,6 @@ gpu.func @no_store_tensor(%vec: vector<8x16xf32>,
 // STORE-SCATTER:        vector.transfer_write
 }
 
-// -----
-gpu.module @xevm_module {
-gpu.func @no_store_high_dim_vector(%vec: vector<8x16x32xf32>,
-    %source: memref<16x32x64xf32>, %offset: index) {
-  vector.transfer_write %vec, %source[%offset, %offset, %offset]
-    {in_bounds = [true, true, true]}
-    : vector<8x16x32xf32>, memref<16x32x64xf32>
-  gpu.return
-}
-
-// STORE-ND-LABEL: @no_store_high_dim_vector(
-// STORE-ND:       vector.transfer_write
-
-// STORE-SCATTER-LABEL:  @no_store_high_dim_vector(
-// STORE-SCATTER-SAME:   %[[VEC:.+]]: vector<8x16x32xf32>,
-// STORE-SCATTER-SAME:   %[[SRC:.+]]: memref<16x32x64xf32>
-// STORE-SCATTER:        %[[CST:.+]] = arith.constant dense<true> : vector<8x16x32xi1>
-// STORE-SCATTER:        %[[CST_0:.+]] = arith.constant dense<64> : vector<16xindex>
-// STORE-SCATTER:        %[[CST_1:.+]] = arith.constant dense<2048> : vector<8xindex>
-// STORE-SCATTER:        %[[C2048:.+]] = arith.constant 2048 : index
-// STORE-SCATTER:        %[[C64:.+]] = arith.constant 64 : index
-// STORE-SCATTER-COUNT-3: vector.step
-// STORE-SCATTER-COUNT-3: vector.shape_cast
-// STORE-SCATTER-COUNT-3: vector.broadcast {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER-COUNT-2: arith.addi {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER:        %[[BCASTOFF:.+]] = vector.broadcast {{.*}} : index to vector<8x16x32xindex>
-// STORE-SCATTER:        %[[IDX:.+]] = arith.addi %[[BCASTOFF]], {{.*}} : vector<8x16x32xindex>
-// STORE-SCATTER:        %[[COLLAPSE:.+]] = memref.collapse_shape %[[SRC]] {{\[}}[0, 1, 2]{{\]}} : memref<16x32x64xf32> into memref<32768xf32>
-// STORE-SCATTER:        xegpu.store %[[VEC]], %[[COLLAPSE]][%[[IDX]]], %[[CST]] : vector<8x16x32xf32>, memref<32768xf32>, vector<8x16x32xindex>, vector<8x16x32xi1> 
-}
-
 // -----
 gpu.module @xevm_module {
 gpu.func @no_store_non_unit_inner_stride(%vec: vector<8xf32>,
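The chip gate added in this commit decides which lowering is attempted, matching the paired RUN lines in the tests: with chip=pvc (or bmg) the nd-descriptor path is kept, otherwise the scattered lowering is tried after the extra legality checks. The sketch below models just that decision with std::optional, mirroring the result type of xegpu::getXeGPUChipStr; the surrounding pattern control flow is not visible in this hunk and is not reproduced here.

#include <optional>
#include <string>

static bool takesScatteredPath(const std::optional<std::string> &chip) {
  // TODO in the patch: replace this string comparison with a proper uArch
  // capability check. A missing chip (std::nullopt) also selects the
  // scattered path, which is what the prefix-only RUN lines exercise.
  return chip != "pvc" && chip != "bmg";
}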

>From 77fae691d0934cf0532d27f725f95c7e05965d37 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 12 Aug 2025 17:49:00 +0000
Subject: [PATCH 13/17] address feedback

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 119 +++++++-----------
 1 file changed, 48 insertions(+), 71 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index c64cccab3b261..84871cfccf09d 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -174,7 +174,7 @@ extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
             xferOp, "Vector rank cannot exceed memref rank");
       // Compare the last vectorRank dimensions of memref with vector shape
       for (size_t i = 0; i < vectorRank; ++i) {
-        if (memrefShape[memrefRank - vectorRank + i] <= vectorShape[i])
+        if (vectorShape[i] > memrefShape[memrefRank - vectorRank + i])
           return rewriter.notifyMatchFailure(
               xferOp, "Memref lower dimensions must match vector shape");
       }
@@ -225,15 +225,14 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
   Value baseMemref = xferOp.getBase();
   AffineMap permMap = xferOp.getPermutationMap();
   VectorType vectorType = xferOp.getVectorType();
-  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
 
   Location loc = xferOp.getLoc();
   if (memrefType.hasStaticShape()) {
     int64_t offset;
     SmallVector<int64_t> intStrides;
-    if (failed(memrefType.getStridesAndOffset(intStrides, offset))) {
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
       return {};
-    }
     // Wrap static strides as MLIR values
     for (int64_t s : intStrides)
       strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
@@ -320,7 +319,7 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
   for (size_t i = 0; i < vectorRank; ++i) {
     size_t memrefDim = memrefRank - vectorRank + i;
     Value strideValue = strides[memrefDim];
-    auto mulType = llvm::cast<VectorType>(stepVectors[i].getType());
+    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
     auto bcastOp =
         vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
     auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
@@ -381,32 +380,21 @@ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
   Location loc = xferOp.getLoc();
 
   Value baseMemref = xferOp.getBase();
-  MemRefType memrefType = llvm::cast<MemRefType>(baseMemref.getType());
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
   Type elementType = memrefType.getElementType();
 
   // Compute the total number of elements in the memref
-  int64_t totalElements = 1;
-  bool hasDynamicDim = false;
-  for (int64_t dim : memrefType.getShape()) {
-    if (dim == ShapedType::kDynamic) {
-      hasDynamicDim = true;
-      break;
-    }
-    totalElements *= dim;
-  }
-
   MemRefType flatMemrefType;
-  if (hasDynamicDim) {
-    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
-  } else {
+  if (memrefType.hasStaticShape()) {
+    auto totalElements = memrefType.getNumElements();
     flatMemrefType = MemRefType::get({totalElements}, elementType);
+  } else {
+    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
   }
 
   SmallVector<ReassociationIndices> reassociation;
-  ReassociationIndices allDims;
-  for (int i = 0; i < memrefType.getRank(); ++i) {
-    allDims.push_back(i);
-  }
+  ReassociationIndices allDims =
+      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
   reassociation.push_back(allDims);
 
   auto collapseOp = memref::CollapseShapeOp::create(
@@ -414,13 +402,24 @@ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
   return collapseOp;
 }
 
-// Create XeGPU gather load operation
-static LogicalResult createLoadGather(vector::TransferReadOp readOp,
-                                      PatternRewriter &rewriter,
-                                      Value flatMemref, Value localOffsets) {
+LogicalResult lowerToRegularLoadOp(vector::TransferReadOp readOp,
+                                   PatternRewriter &rewriter) {
+
   Location loc = readOp.getLoc();
   VectorType vectorType = readOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
+  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+  if (!memrefType)
+    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+  SmallVector<Value> strides = computeStrides(readOp, rewriter);
+  if (strides.empty())
+    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+  Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
   Value mask = vector::ConstantMaskOp::create(
       rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
       vectorShape);
@@ -430,50 +429,17 @@ static LogicalResult createLoadGather(vector::TransferReadOp readOp,
       /*l1_hint=*/xegpu::CachePolicyAttr{},
       /*l2_hint=*/xegpu::CachePolicyAttr{},
       /*l3_hint=*/xegpu::CachePolicyAttr{});
+
   rewriter.replaceOp(readOp, gatherOp.getResult());
   return success();
 }
 
-// Create XeGPU store scatter operation
-static LogicalResult createStoreScatter(vector::TransferWriteOp writeOp,
-                                        PatternRewriter &rewriter, Value value,
-                                        Value flatMemref, Value localOffsets) {
+LogicalResult lowerToRegularStoreOp(vector::TransferWriteOp writeOp,
+                                    PatternRewriter &rewriter) {
+
   Location loc = writeOp.getLoc();
   VectorType vectorType = writeOp.getVectorType();
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
-  Value mask = vector::ConstantMaskOp::create(
-      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
-      vectorShape);
-  xegpu::StoreScatterOp::create(rewriter, loc, value, flatMemref, localOffsets,
-                                mask,
-                                /*chunk_size=*/IntegerAttr{},
-                                /*l1_hint=*/xegpu::CachePolicyAttr{},
-                                /*l2_hint=*/xegpu::CachePolicyAttr{},
-                                /*l3_hint=*/xegpu::CachePolicyAttr{});
-  rewriter.eraseOp(writeOp);
-  return success();
-}
-
-LogicalResult lowerTransferReadToLoadOp(vector::TransferReadOp readOp,
-                                        PatternRewriter &rewriter) {
-
-  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
-  if (!memrefType)
-    return rewriter.notifyMatchFailure(readOp, "Expected memref source");
-
-  SmallVector<Value> strides = computeStrides(readOp, rewriter);
-  if (strides.empty())
-    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
-
-  Value localOffsets = computeOffsets(readOp, rewriter, strides);
-
-  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
-
-  return createLoadGather(readOp, rewriter, flatMemref, localOffsets);
-}
-
-LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
-                                          PatternRewriter &rewriter) {
 
   auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
   if (!memrefType)
@@ -485,8 +451,17 @@ LogicalResult lowerTransferWriteToStoreOp(vector::TransferWriteOp writeOp,
 
   Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
 
-  return createStoreScatter(writeOp, rewriter, writeOp.getVector(), flatMemref,
-                            localOffsets);
+  Value mask = vector::ConstantMaskOp::create(
+      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+      vectorShape);
+  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+                                localOffsets, mask,
+                                /*chunk_size=*/IntegerAttr{},
+                                /*l1_hint=*/xegpu::CachePolicyAttr{},
+                                /*l2_hint=*/xegpu::CachePolicyAttr{},
+                                /*l3_hint=*/xegpu::CachePolicyAttr{});
+  rewriter.eraseOp(writeOp);
+  return success();
 }
 
 struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
@@ -499,14 +474,15 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
-    auto chip = xegpu::getXeGPUChipStr(readOp);
+    // lower to regular load Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
-    if ( chip != "pvc" && chip != "bmg") {
+    auto chip = xegpu::getXeGPUChipStr(readOp);
+    if (chip != "pvc" && chip != "bmg") {
       // perform additional checks -
       if (failed(extraCheckForScatteredLoadStore(readOp, rewriter)))
         return failure();
       // calling another function that lower TransferReadOp to regular Loadop
-      return lowerTransferReadToLoadOp(readOp, rewriter);
+      return lowerToRegularLoadOp(readOp, rewriter);
     }
 
     // Perform common data transfer checks.
@@ -569,14 +545,15 @@ struct TransferWriteLowering
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
 
-    auto chip = xegpu::getXeGPUChipStr(writeOp);
+    // lower to regular write Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
+    auto chip = xegpu::getXeGPUChipStr(writeOp);
     if (chip != "pvc" && chip != "bmg") {
       // perform additional checks -
       if (failed(extraCheckForScatteredLoadStore(writeOp, rewriter)))
         return failure();
       // calling another function that lower TransferWriteOp to regular StoreOp
-      return lowerTransferWriteToStoreOp(writeOp, rewriter);
+      return lowerToRegularStoreOp(writeOp, rewriter);
     }
 
     // Perform common data transfer checks.

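For reference, the collapse-based addressing that lowerToRegularLoadOp now builds inline can be pictured with a minimal sketch; the shapes, SSA names, and the exact xegpu.load assembly below are illustrative assumptions rather than output copied from the patch or its tests:

  // Hypothetical 2-D read: the whole source is folded into one reassociation
  // group [[0, 1]] and then addressed with per-element linear offsets.
  %flat = memref.collapse_shape %src [[0, 1]]
      : memref<4x8xf32> into memref<32xf32>
  %mask = vector.constant_mask [4, 8] : vector<4x8xi1>
  %v = xegpu.load %flat[%offsets], %mask
      : memref<32xf32>, vector<4x8xindex>, vector<4x8xi1> -> vector<4x8xf32>

Here %offsets stands for the index vector built from the vector.step, vector.shape_cast, vector.broadcast and arith ops the pass emits, as in the CHECK lines earlier in the test file.
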
>From cc23a681c477dcdf3b55b74e32d01f559c0aaa3c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 12 Aug 2025 20:33:15 +0000
Subject: [PATCH 14/17] address comments

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  2 +-
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 78 +++++++++----------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  2 +-
 3 files changed, 38 insertions(+), 44 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 67f74a4fa2e0e..109899d6b7977 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -123,7 +123,7 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 void doSCFStructuralTypeConversionWithTensorType(Operation *op,
                                                  TypeConverter converter);
 
-std::optional<std::string> getXeGPUChipStr(Operation *op);
+std::optional<std::string> getChipStr(Operation *op);
 
 } // namespace xegpu
 
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 84871cfccf09d..8b9834610335f 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -152,37 +152,33 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
-static LogicalResult
-extraCheckForScatteredLoadStore(VectorTransferOpInterface xferOp,
-                                PatternRewriter &rewriter) {
-  // 1. it must be inbound access by checking in_bounds attributes, like
-  // {in_bounds = [false, true]}
-  if (xferOp.hasOutOfBoundsDim())
-    return rewriter.notifyMatchFailure(xferOp,
-                                       "Out-of-bounds access is not supported "
-                                       "for scatter load/store lowering");
-  // 2. if the memref has static shape, its lower rank must exactly match with
-  // vector shape.
-  if (auto memrefType = dyn_cast<MemRefType>(xferOp.getShapedType())) {
-    if (memrefType.hasStaticShape()) {
-      ArrayRef<int64_t> memrefShape = memrefType.getShape();
-      ArrayRef<int64_t> vectorShape = xferOp.getVectorType().getShape();
-      size_t memrefRank = memrefShape.size();
-      size_t vectorRank = vectorShape.size();
-      if (vectorRank > memrefRank)
-        return rewriter.notifyMatchFailure(
-            xferOp, "Vector rank cannot exceed memref rank");
-      // Compare the last vectorRank dimensions of memref with vector shape
-      for (size_t i = 0; i < vectorRank; ++i) {
-        if (vectorShape[i] > memrefShape[memrefRank - vectorRank + i])
-          return rewriter.notifyMatchFailure(
-              xferOp, "Memref lower dimensions must match vector shape");
-      }
-    }
-  }
-  return success();
-}
-
+/// Adjusts the strides of a memref according to a given permutation map for
+/// vector operations.
+///
+/// This function updates the last `vecRank` elements of the `strides` array to
+/// reflect the permutation specified by `permMap`. The permutation is applied
+/// to the innermost dimensions of the memref, corresponding to the vector
+/// shape. This is typically used when lowering vector transfer operations with
+/// permutation maps to memory accesses, ensuring that the memory strides match
+/// the logical permutation of vector dimensions.
+///
+/// Example:
+///   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]` and a
+///   vector of rank 2. If the permutation map swaps the last two dimensions
+///   (e.g., [0, 1] -> [1, 0]), then after calling this function, the last two
+///   strides will be swapped:
+///     Original strides: [s0, s1, s2, s3]
+///     After permutation: [s0, s1, s3, s2]
+///
+/// \param op The operation being rewritten.
+/// \param rewriter The pattern rewriter for IR modifications.
+/// \param memrefType The type of the memref being accessed.
+/// \param permMap The affine permutation map to apply to the vector dimensions.
+/// \param vecType The type of the vector being accessed.
+/// \param strides The array of strides to be adjusted (in-place).
+/// \returns success if the permutation is applied successfully, failure
+/// otherwise.
+///
 static LogicalResult adjustStridesForPermutation(
     Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
     AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
@@ -200,13 +196,11 @@ static LogicalResult adjustStridesForPermutation(
   for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
     AffineExpr expr = permMap.getResult(outIdx);
     auto dimExpr = dyn_cast<AffineDimExpr>(expr);
-    if (!dimExpr) {
-      return rewriter.notifyMatchFailure(op, "Unsupported permutation expr");
-    }
+    assert(dimExpr && "The permutation expr must be affine expression");
     unsigned pos = dimExpr.getPosition();
     // Map permutation to the relevant strides (innermost dims)
-    if (pos < memrefRank - vecRank)
-      return rewriter.notifyMatchFailure(op, "Permutation out of bounds");
+    assert((pos >= (memrefRank - vecRank)) &&
+           "Permuted index must be in the inner dimensions");
 
     // The stride for output dimension outIdx is the stride of input dimension
     // pos
@@ -476,10 +470,10 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
 
     // lower to regular load Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
-    auto chip = xegpu::getXeGPUChipStr(readOp);
+    auto chip = xegpu::getChipStr(readOp);
     if (chip != "pvc" && chip != "bmg") {
-      // perform additional checks -
-      if (failed(extraCheckForScatteredLoadStore(readOp, rewriter)))
+      // TODO: add support for OutOfBound access
+      if (readOp.hasOutOfBoundsDim())
         return failure();
       // calling another function that lower TransferReadOp to regular Loadop
       return lowerToRegularLoadOp(readOp, rewriter);
@@ -547,10 +541,10 @@ struct TransferWriteLowering
 
     // lower to regular write Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
-    auto chip = xegpu::getXeGPUChipStr(writeOp);
+    auto chip = xegpu::getChipStr(writeOp);
     if (chip != "pvc" && chip != "bmg") {
-      // perform additional checks -
-      if (failed(extraCheckForScatteredLoadStore(writeOp, rewriter)))
+      // TODO: add support for OutOfBound access
+      if (writeOp.hasOutOfBoundsDim())
         return failure();
       // calling another function that lower TransferWriteOp to regular StoreOp
       return lowerToRegularStoreOp(writeOp, rewriter);
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 1f090952894d2..552f3a1c9b6ca 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -407,7 +407,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
   }
 }
 
-std::optional<std::string> xegpu::getXeGPUChipStr(Operation *op) {
+std::optional<std::string> xegpu::getChipStr(Operation *op) {
   auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
 
   if (!gpuModuleOp)

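To make the stride-permutation comment concrete, a transposing read of the kind adjustStridesForPermutation handles could look like the sketch below; %c0 : index and %pad : f32 are assumed to be defined, and the shapes are illustrative only:

  // Illustrative only: the map sends the two innermost memref dims (d1, d2) to
  // the vector dims in swapped order, so the lowering pairs the vector dims
  // with strides (1, 64) instead of the row-major (64, 1).
  %v = vector.transfer_read %src[%c0, %c0, %c0], %pad
      {in_bounds = [true, true],
       permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>}
      : memref<16x32x64xf32>, vector<64x32xf32>
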
>From 99b8035597811f079ea5dfe5dd0eb8801a6df170 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 12 Aug 2025 21:13:45 +0000
Subject: [PATCH 15/17] address comments

---
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 62 ++++++++-----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 19 +++---
 mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt   |  1 -
 mlir/test/Dialect/XeGPU/invalid.mlir          |  4 +-
 4 files changed, 38 insertions(+), 48 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8b9834610335f..a8e2fef90c1cd 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -152,41 +152,33 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
   return ndDesc;
 }
 
-/// Adjusts the strides of a memref according to a given permutation map for
-/// vector operations.
-///
-/// This function updates the last `vecRank` elements of the `strides` array to
-/// reflect the permutation specified by `permMap`. The permutation is applied
-/// to the innermost dimensions of the memref, corresponding to the vector
-/// shape. This is typically used when lowering vector transfer operations with
-/// permutation maps to memory accesses, ensuring that the memory strides match
-/// the logical permutation of vector dimensions.
-///
-/// Example:
-///   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]` and a
-///   vector of rank 2. If the permutation map swaps the last two dimensions
-///   (e.g., [0, 1] -> [1, 0]), then after calling this function, the last two
-///   strides will be swapped:
-///     Original strides: [s0, s1, s2, s3]
-///     After permutation: [s0, s1, s3, s2]
-///
-/// \param op The operation being rewritten.
-/// \param rewriter The pattern rewriter for IR modifications.
-/// \param memrefType The type of the memref being accessed.
-/// \param permMap The affine permutation map to apply to the vector dimensions.
-/// \param vecType The type of the vector being accessed.
-/// \param strides The array of strides to be adjusted (in-place).
-/// \returns success if the permutation is applied successfully, failure
-/// otherwise.
-///
-static LogicalResult adjustStridesForPermutation(
-    Operation *op, PatternRewriter &rewriter, MemRefType memrefType,
-    AffineMap permMap, VectorType vecType, SmallVectorImpl<Value> &strides) {
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the last `vecRank` elements of the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is applied
+// to the innermost dimensions of the memref, corresponding to the vector
+// shape. This is typically used when lowering vector transfer operations with
+// permutation maps to memory accesses, ensuring that the memory strides match
+// the logical permutation of vector dimensions.
+//
+// Example:
+//   Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]` and a
+//   vector of rank 2. If the permutation map swaps the last two dimensions
+//   (e.g., [0, 1] -> [1, 0]), then after calling this function, the last two
+//   strides will be swapped:
+//     Original strides: [s0, s1, s2, s3]
+//     After permutation: [s0, s1, s3, s2]
+//
+void adjustStridesForPermutation(Operation *op, PatternRewriter &rewriter,
+                                 MemRefType memrefType, AffineMap permMap,
+                                 VectorType vecType,
+                                 SmallVectorImpl<Value> &strides) {
   unsigned vecRank;
   unsigned memrefRank = memrefType.getRank();
 
   if (permMap.isMinorIdentity())
-    return success();
+    return;
   vecRank = vecType.getRank();
   // Only adjust the last vecRank strides according to the permutation
   ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
@@ -209,8 +201,6 @@ static LogicalResult adjustStridesForPermutation(
   // Replace the last vecRank strides with the adjusted ones
   for (unsigned i = 0; i < vecRank; ++i)
     strides[memrefRank - vecRank + i] = adjustedStrides[i];
-
-  return success();
 }
 
 SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
@@ -254,10 +244,8 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     strides.append(meta.getStrides().begin(), meta.getStrides().end());
   }
   // Adjust strides according to the permutation map (e.g., for transpose)
-  if (failed(adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap,
-                                         vectorType, strides))) {
-    return {};
-  }
+  adjustStridesForPermutation(xferOp, rewriter, memrefType, permMap, vectorType,
+                              strides);
   return strides;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8f67339f7cfe8..c5a11ad928e54 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -121,20 +121,23 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
 
-  // a valid shape for SIMT case
   if (valueTy.getRank() == 1) {
     auto maskVecTy = dyn_cast<VectorType>(maskTy);
     if (!maskVecTy)
       return emitError() << "Expecting a vector type mask.";
-    int64_t maskElements = maskVecTy.getNumElements();
+    int64_t maskSize = maskVecTy.getNumElements();
 
     auto valueSize = valueTy.getNumElements();
-    if ((valueSize % chunkSize) != 0)
-      return emitError() << "value elements must be multiple of chunk size "
-                         << chunkSize;
-    if ((valueSize / chunkSize) != maskElements)
-      return emitError()
-             << "Mask should match value except the chunk size dim.";
+    if (chunkSize == 1) {
+      if (maskSize != valueSize)
+        return emitError()
+               << "Mask should match value except the chunk size dim.";
+    } else {
+      if (chunkSize != valueSize)
+        return emitError() << "value elements must match chunk size "
+                           << chunkSize;
+    }
+
     return success();
   }
 
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 23c26875476b6..d9bf4a1461c27 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -8,7 +8,6 @@ add_mlir_dialect_library(MLIRXeGPUUtils
   MLIRIR
   MLIRSCFTransforms
   MLIRGPUDialect
-  MLIRLLVMDialect
   MLIRXeVMDialect
   MLIRXeGPUDialect
   )
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 9a1a3de9e233a..05e0688639e43 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -407,7 +407,7 @@ func.func @load_gather_offset_sg(%src: memref<?xf16>) {
 func.func @load_gather_offset_wi(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error at +1 {{value elements must be multiple of chunk size}}
+  // expected-error at +1 {{value elements must match chunk size}}
   %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<3xf32>
   return
 }
@@ -438,7 +438,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
 func.func @load_gather_offset_wi_2(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
-  // expected-error at +1 {{value elements must be multiple of chunk size}}
+  // expected-error at +1 {{value elements must match chunk size}}
   %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<3xf16>
   return
 }

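For contrast with the two negative tests updated above, a SIMT-style load that the revised check should accept (a hypothetical case, assuming no other verifier rule interferes; only the result shape differs from the existing test) is:

  // Hypothetical positive counterpart to load_gather_offset_wi: chunk_size (2)
  // equals the number of value elements, so the new check is satisfied.
  func.func @load_gather_offset_wi_ok(%src: ui64) {
    %mask = arith.constant dense<1> : vector<1xi1>
    %offsets = arith.constant dense<[0]> : vector<1xindex>
    %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
    return
  }
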
>From aea47850c8deb018ede62b183a3d6eecd49c07a9 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 13 Aug 2025 21:38:29 +0000
Subject: [PATCH 16/17] polish

---
 .../mlir/Dialect/XeGPU/Utils/XeGPUUtils.h     |  3 +
 .../VectorToXeGPU/VectorToXeGPU.cpp           | 81 ++++++++-----------
 mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp   |  2 +-
 3 files changed, 38 insertions(+), 48 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
index 109899d6b7977..db8608c6d20b8 100644
--- a/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
+++ b/mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h
@@ -123,6 +123,9 @@ Value createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
 void doSCFStructuralTypeConversionWithTensorType(Operation *op,
                                                  TypeConverter converter);
 
+/// Retrieves the chip string from the XeVM target attribute of the parent
+/// GPU module operation. Returns the chip identifier if found, or nullopt
+/// if no GPU module parent or XeVM target attribute exists.
 std::optional<std::string> getChipStr(Operation *op);
 
 } // namespace xegpu
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index a8e2fef90c1cd..d6400df19a320 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -170,39 +171,22 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
 //     Original strides: [s0, s1, s2, s3]
 //     After permutation: [s0, s1, s3, s2]
 //
-void adjustStridesForPermutation(Operation *op, PatternRewriter &rewriter,
-                                 MemRefType memrefType, AffineMap permMap,
-                                 VectorType vecType,
-                                 SmallVectorImpl<Value> &strides) {
-  unsigned vecRank;
-  unsigned memrefRank = memrefType.getRank();
-
-  if (permMap.isMinorIdentity())
-    return;
-  vecRank = vecType.getRank();
-  // Only adjust the last vecRank strides according to the permutation
-  ArrayRef<Value> relevantStrides = ArrayRef<Value>(strides).take_back(vecRank);
-  SmallVector<Value> adjustedStrides(vecRank);
-  // For each output dimension in the permutation map, find which input dim it
-  // refers to, and assign the corresponding stride.
-  for (unsigned outIdx = 0; outIdx < vecRank; ++outIdx) {
-    AffineExpr expr = permMap.getResult(outIdx);
-    auto dimExpr = dyn_cast<AffineDimExpr>(expr);
-    assert(dimExpr && "The permutation expr must be affine expression");
-    unsigned pos = dimExpr.getPosition();
-    // Map permutation to the relevant strides (innermost dims)
-    assert((pos >= (memrefRank - vecRank)) &&
-           "Permuted index must be in the inner dimensions");
-
-    // The stride for output dimension outIdx is the stride of input dimension
-    // pos
-    adjustedStrides[outIdx] = relevantStrides[pos - (memrefRank - vecRank)];
-  }
-  // Replace the last vecRank strides with the adjusted ones
-  for (unsigned i = 0; i < vecRank; ++i)
-    strides[memrefRank - vecRank + i] = adjustedStrides[i];
+static void adjustStridesForPermutation(Operation *op,
+                                        PatternRewriter &rewriter,
+                                        MemRefType memrefType,
+                                        AffineMap permMap, VectorType vecType,
+                                        SmallVectorImpl<Value> &strides) {
+
+  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+  SmallVector<unsigned> perms;
+  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+  SmallVector<int64_t> perms64(perms.begin(), perms.end());
+  strides = applyPermutation(strides, perms64);
 }
 
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
 SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
                                   PatternRewriter &rewriter) {
   SmallVector<Value> strides;
@@ -232,12 +216,12 @@ SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
     resultTypes.push_back(MemRefType::get(
         {}, memrefType.getElementType())); // base memref (unranked)
     resultTypes.push_back(indexType);      // offset
-    for (unsigned i = 0; i < rank; ++i) {
+
+    for (unsigned i = 0; i < rank; ++i)
       resultTypes.push_back(indexType); // strides
-    }
-    for (unsigned i = 0; i < rank; ++i) {
+
+    for (unsigned i = 0; i < rank; ++i)
       resultTypes.push_back(indexType); // sizes
-    }
 
     auto meta = memref::ExtractStridedMetadataOp::create(
         rewriter, loc, resultTypes, baseMemref);
@@ -288,11 +272,12 @@ static Value computeOffsets(VectorTransferOpInterface xferOp,
 
   // Create vector.step operations for each dimension
   SmallVector<Value> stepVectors;
-  for (int64_t dim : vectorShape) {
+  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
     auto stepType = VectorType::get({dim}, rewriter.getIndexType());
     auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
     stepVectors.push_back(stepOp);
-  }
+    return stepOp;
+  });
 
   // Multiply step vectors by corresponding strides
   size_t memrefRank = strides.size();
@@ -384,8 +369,8 @@ static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
   return collapseOp;
 }
 
-LogicalResult lowerToRegularLoadOp(vector::TransferReadOp readOp,
-                                   PatternRewriter &rewriter) {
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+                                            PatternRewriter &rewriter) {
 
   Location loc = readOp.getLoc();
   VectorType vectorType = readOp.getVectorType();
@@ -416,8 +401,8 @@ LogicalResult lowerToRegularLoadOp(vector::TransferReadOp readOp,
   return success();
 }
 
-LogicalResult lowerToRegularStoreOp(vector::TransferWriteOp writeOp,
-                                    PatternRewriter &rewriter) {
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+                                             PatternRewriter &rewriter) {
 
   Location loc = writeOp.getLoc();
   VectorType vectorType = writeOp.getVectorType();
@@ -456,15 +441,16 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
     if (failed(transferPreconditions(rewriter, readOp)))
       return failure();
 
-    // lower to regular load Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
     auto chip = xegpu::getChipStr(readOp);
     if (chip != "pvc" && chip != "bmg") {
+      // lower to scattered load Op if the target HW doesn't have 2d block load
+      // support
       // TODO: add support for OutOfBound access
       if (readOp.hasOutOfBoundsDim())
         return failure();
-      // calling another function that lower TransferReadOp to regular Loadop
-      return lowerToRegularLoadOp(readOp, rewriter);
+      // lower TransferReadOp to scattered Loadop
+      return lowerToScatteredLoadOp(readOp, rewriter);
     }
 
     // Perform common data transfer checks.
@@ -527,15 +513,16 @@ struct TransferWriteLowering
     if (failed(transferPreconditions(rewriter, writeOp)))
       return failure();
 
-    // lower to regular write Op if the target HW is not PVC
     // TODO:This check needs to be replaced with proper uArch capability check
     auto chip = xegpu::getChipStr(writeOp);
     if (chip != "pvc" && chip != "bmg") {
+      // lower to scattered store Op if the target HW doesn't have 2d block
+      // store support
       // TODO: add support for OutOfBound access
       if (writeOp.hasOutOfBoundsDim())
         return failure();
-      // calling another function that lower TransferWriteOp to regular StoreOp
-      return lowerToRegularStoreOp(writeOp, rewriter);
+      // lower TransferWriteOp to scattered StoreOp
+      return lowerToScatteredStoreOp(writeOp, rewriter);
     }
 
     // Perform common data transfer checks.
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 552f3a1c9b6ca..19eedbac0f76b 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -423,4 +423,4 @@ std::optional<std::string> xegpu::getChipStr(Operation *op) {
   }
 
   return std::nullopt;
-}
\ No newline at end of file
+}

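The step/broadcast/add chain touched in computeOffsets above turns the (possibly permuted) strides into per-element linear offsets; for a contiguous 4x8 access the emitted IR is roughly the sketch below, with the shapes and stride constant assumed rather than taken from the pass output (the multiply by the unit innermost stride is omitted):

  // Sketch only: scale the outer index by its stride (8 for a row-major 4x8
  // source), then broadcast both contributions to the full shape and add.
  %s0 = vector.step : vector<4xindex>
  %s1 = vector.step : vector<8xindex>
  %c8 = arith.constant dense<8> : vector<4xindex>
  %row = arith.muli %s0, %c8 : vector<4xindex>
  %row2d = vector.shape_cast %row : vector<4xindex> to vector<4x1xindex>
  %col2d = vector.shape_cast %s1 : vector<8xindex> to vector<1x8xindex>
  %rowb = vector.broadcast %row2d : vector<4x1xindex> to vector<4x8xindex>
  %colb = vector.broadcast %col2d : vector<1x8xindex> to vector<4x8xindex>
  %offsets = arith.addi %rowb, %colb : vector<4x8xindex>
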
>From 762f7c38e3fa2d7565c1518dfc33ee7bee27f21c Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 13 Aug 2025 22:01:42 +0000
Subject: [PATCH 17/17] remove unnecessary comments

---
 mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index d6400df19a320..db0ee8432a2ca 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -449,7 +449,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
       // TODO: add support for OutOfBound access
       if (readOp.hasOutOfBoundsDim())
         return failure();
-      // lower TransferReadOp to scattered Loadop
       return lowerToScatteredLoadOp(readOp, rewriter);
     }
 
@@ -521,7 +520,6 @@ struct TransferWriteLowering
       // TODO: add support for OutOfBound access
       if (writeOp.hasOutOfBoundsDim())
         return failure();
-      // lower TransferWriteOp to scattered StoreOp
       return lowerToScatteredStoreOp(writeOp, rewriter);
     }
 


