[Mlir-commits] [mlir] [MLIR][XeGPU] XeVM lowering support for load_matrix/store_matrix (PR #162780)

Jianhui Li llvmlistbot at llvm.org
Wed Oct 15 16:36:57 PDT 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/162780

>From 4c58d3d6a627f23425528668ffff92bcca8f1461 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 7 Oct 2025 22:08:30 +0000
Subject: [PATCH 01/12] pass basic lowering test

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  22 ++
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   6 +-
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  50 +++-
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 124 ++++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 216 ++++++++++++++++++
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  20 +-
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |   6 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  40 ++++
 8 files changed, 466 insertions(+), 18 deletions(-)
 create mode 100644 mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 5695d5d515d7f..601e966b49890 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -716,8 +716,30 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
       return getAttrs().getAs<ArrayAttr>("stride");
     }
 
+    ArrayAttr getBlockAttr() {
+      return getAttrs().getAs<ArrayAttr>("block");
+    }
+
   }];
 
 }
 
+def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
+def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
+def MatrixAccessDirection : 
+    I32EnumAttr<"MatrixAccessDirection", 
+    "Matrix elements/vectors can have row or column direction", [
+    RowOriented, ColOriented
+]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::xegpu";
+}
+def MatrixAccessDirectionAttr : 
+    EnumAttr<XeGPU_Dialect, 
+    MatrixAccessDirection, 
+    "matrix_access_direction">{
+  let summary = [{Describes the direction of memory access for load_matrix and store_matrix.}];
+  let assemblyFormat = "`<` $value `>`";
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73f9061f5debe..32d21bae8cd34 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1298,8 +1298,7 @@ def XeGPU_CreateMemDescOp: XeGPU_Op<"create_mem_desc", [Pure,
 }
 
 def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
-                              AllElementTypesMatch<["mem_desc", "res"]>,
-                              AllRanksMatch<["mem_desc", "res"]>]>  {
+                              AllElementTypesMatch<["mem_desc", "res"]>]>  {
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
@@ -1344,8 +1343,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 }
 
 def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
-                              AllElementTypesMatch<["mem_desc", "data"]>,
-                              AllRanksMatch<["mem_desc", "data"]>]> {
+                              AllElementTypesMatch<["mem_desc", "data"]>]> {
   let arguments = (ins
     XeGPU_ValueType:$data,
     XeGPU_MemDesc:$mem_desc,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 84902b2039643..c261fbb576642 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,7 +237,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
     }
 
-    ArrayAttr getStrides() {
+    ArrayAttr getStridesAttr() {
       auto layout = getMemLayout();
       if (layout && layout.hasAttr("stride")) {
         return layout.getStrides();
@@ -250,6 +250,54 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       Builder builder(getContext());
       return builder.getI64ArrayAttr(defaultStrides);
     }
+
+    /// Heuristic to determine if the MemDesc uses column-major layout,
+    /// based on the rank and the value of the first stride dimension.
+    bool isColMajor() {
+      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+    }
+
+    // Get the block shape for a MemDescType, which is represented
+    // as an attribute in MemDescType. By default it is the shape
+    // of the MemDescType.
+    SmallVector<int64_t> getBlockSize() {
+      SmallVector<int64_t> size(getShape());
+      MemLayoutAttr layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        ArrayAttr attr = layout.getBlockAttr();
+        size.clear();
+        llvm::for_each(attr, [&](Attribute elem) {
+          if (auto intElem = dyn_cast<IntegerAttr>(elem))
+            size.push_back(intElem.getInt());
+        });
+      }
+      return size;
+    }
+
+    // Get the strides as a vector of integers.
+    // If the layout contains a block attribute, the returned strides are the
+    // blocked strides.
+    //
+    // The blocking is applied against the original matrix shape
+    // so that the linear offset is not impacted by the subview.
+    //
+    // It first computes the original matrix shape using the stride info,
+    // then computes the number of blocks in each dimension of the original
+    // shape, then computes the outer block shape and stride, and finally
+    // combines the inner and outer block shapes and strides.
+    // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+    // its memory layout tuple is ([2,32,16,8],[128,256,1,16]);
+    // for mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
+    // its memory layout tuple is ([32,2,8,16],[256,128,16,1]).
+    SmallVector<int64_t> getStrides();
+    
+    /// Generates instructions to compute the linearized offset.
+    /// If the memory descriptor is blocked, the offset is linearized
+    /// according to the blocked layout; the strides of the memory
+    /// descriptor are always taken into account, blocked or not.
+    Value getLinearOffsets(OpBuilder &builder,
+               Location loc, ArrayRef<OpFoldResult> offsets);
+
+
   }];
 
   let hasCustomAssemblyFormat = true;
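For reference, the blocked-stride derivation described in the getStrides comment
above can be checked with a small standalone sketch (plain C++, no MLIR
dependencies; all names are illustrative and not part of this patch). It
reproduces the two documented layout tuples:

// Standalone sketch of the blocked-stride derivation documented above.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Given the (possibly strided) matrix strides and an inner block shape, return
// the strides of the 2*rank blocked layout: [outer..., inner...].
static std::vector<int64_t> blockedStrides(const std::vector<int64_t> &strides,
                                           const std::vector<int64_t> &block) {
  size_t rank = strides.size();
  // Sort dims by increasing stride; the innermost dim must have stride 1.
  std::vector<size_t> perm(rank);
  std::iota(perm.begin(), perm.end(), 0);
  std::sort(perm.begin(), perm.end(),
            [&](size_t a, size_t b) { return strides[a] < strides[b]; });
  assert(strides[perm[0]] == 1 && "innermost dim must have stride 1");

  // Inner strides: elements inside one block are contiguous, following the
  // same dim order as the original strides.
  std::vector<int64_t> inner(rank);
  int64_t running = 1;
  for (size_t i = 0; i < rank; ++i) {
    inner[perm[i]] = running;
    running *= block[perm[i]];
  }
  int64_t blockSize = running; // product of the block extents

  // The extent of every dim except the outermost is derived from the strides;
  // the outermost extent is not needed for the stride computation.
  std::vector<int64_t> outer(rank);
  outer[perm[0]] = blockSize;
  for (size_t i = 0; i + 1 < rank; ++i) {
    int64_t origExtent = strides[perm[i + 1]] / strides[perm[i]];
    int64_t numBlocks = origExtent / block[perm[i]];
    outer[perm[i + 1]] = outer[perm[i]] * numBlocks;
  }

  std::vector<int64_t> result(outer);
  result.insert(result.end(), inner.begin(), inner.end());
  return result;
}

int main() {
  // mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
  //   -> blocked strides [128, 256, 1, 16]
  assert((blockedStrides({1, 32}, {16, 8}) ==
          std::vector<int64_t>{128, 256, 1, 16}));
  // mem_desc<256x32xf16, @block=[8, 16]> with default @strides=[32, 1]
  //   -> blocked strides [256, 128, 16, 1]
  assert((blockedStrides({32, 1}, {8, 16}) ==
          std::vector<int64_t>{256, 128, 16, 1}));
  return 0;
}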
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9ead1d89069d6..666df293bb8be 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -32,6 +32,8 @@
 
 #include <numeric>
 
+#define DEBUG_TYPE "xegpu-to-xevm"
+
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -60,6 +62,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
     return static_cast<int>(xevm::AddrSpace::GLOBAL);
   case xegpu::MemorySpace::SLM:
     return static_cast<int>(xevm::AddrSpace::SHARED);
+  default:
+    llvm_unreachable("Unknown XeGPU memory space");
+    return static_cast<int>(xevm::AddrSpace::GLOBAL);
   }
 }
 
@@ -366,6 +371,7 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                        Value baseAddr, Value offset, int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
       rewriter, loc, rewriter.getI64Type(), elemByteSize);
+  offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
@@ -503,6 +509,113 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
+// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions
+// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
+// 32 bits will be converted to 32 bits.
+class CreateMemDescOpPattern final
+  : public OpConversionPattern<xegpu::CreateMemDescOp> {
+public:
+  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
+          ConversionPatternRewriter &rewriter) const override {
+  // DEBUG: Print operation and types
+  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
+  TypedValue<MemRefType> src = op.getSource();
+  auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+ 
+  // Create the result MemRefType with the same shape, element type, and memory space
+  auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+  
+  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
+  Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+  auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
+                        ValueRange());
+  rewriter.replaceOp(op, viewOp);
+  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
+  return success();
+  }
+};
+
+class MemDescSubviewOpPattern final
+    : public OpConversionPattern<xegpu::MemDescSubviewOp> {
+public:
+  using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    return rewriter.notifyMatchFailure(
+        op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
+  }
+};
+
+
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
+class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+    if (offsets.empty())
+      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
+
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    Value basePtrStruct = adaptor.getMemDesc();
+    Value mdescVal = op.getMemDesc();
+    // Load result or store value type can be a vector or a scalar.
+    Value data;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
+      data = op.getResult();
+    else
+      data = adaptor.getData();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+
+    int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+
+    // Default memory space is SLM.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
+
+    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
+    
+    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
+
+    // Convert base pointer (ptr) to i64
+    Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+
+    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+    basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+
+    // convert base pointer (i64) to LLVM pointer type
+    basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+
+      Value loadOp =
+          LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+      rewriter.replaceOp(op, loadOp);
+    } else {
+      auto storeOp =
+          LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+      rewriter.eraseOp(op);
+    }
+    return success();
+  }
+};
+
 class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
@@ -785,6 +898,13 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
+    // Convert MemDescType into flattened MemRefType for SLM
+    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
+      Type elemTy = type.getElementType();
+      int numElems = type.getNumElements();
+      return MemRefType::get(numElems, elemTy, AffineMap(), 3);
+    });
+
     typeConverter.addConversion([&](MemRefType type) -> Type {
       // Convert MemRefType to i64 type.
       return IntegerType::get(&getContext(), 64);
@@ -919,6 +1039,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
                LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
       typeConverter, patterns.getContext());
+  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
+               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
+               CreateMemDescOpPattern, MemDescSubviewOpPattern>(
+      typeConverter, patterns.getContext());
   patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                       patterns.getContext());
 }
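As a rough illustration of what this lowering produces for an SLM-backed
mem_desc (standalone C++, illustrative names only, not the MLIR API): the
mem_desc is flattened to a 1-D memref in address space 3, and each access
becomes base + linearElementOffset * elementByteSize, mirroring what addOffset
computes:

#include <cassert>
#include <cstdint>

int main() {
  // !xegpu.mem_desc<32x64xf16> flattens to 2048 f16 elements (4096 bytes),
  // matching the memref<4096xi8, 3> backing buffer used in the tests.
  const int64_t numElems = 32 * 64;
  const int64_t elemByteSize = 2; // f16
  assert(numElems * elemByteSize == 4096);

  // Accessing element (row=13, col=7) of the row-major 32x64 matrix:
  const uint64_t base = 0x1000; // hypothetical SLM base address
  const int64_t linearOffset = 13 * 64 + 7; // 839
  const uint64_t addr =
      base + static_cast<uint64_t>(linearOffset) * elemByteSize;
  assert(addr == 0x1000 + 839 * 2);
  return 0;
}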
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 94c5509fd7c29..c64699c12cf4a 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,6 +37,8 @@ void XeGPUDialect::initialize() {
       >();
 }
 
+#define DEBUG_TYPE "xegpu"
+
 /// Generates instructions to compute offsets for a subgroup identified by
 /// its multidimensional indices (sgId), using the specified subgroup layout
 /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -726,6 +728,220 @@ void MemLayoutAttr::print(AsmPrinter &printer) const {
   }
   printer << ">";
 }
+// a helper utility to perform binary operation on OpFoldResult.
+// If both a and b are attributes, it will simply return the result.
+// Otherwise, the corresponding arith op will be generated, and a
+// constant op will be created if one of them is an attribute.
+template <typename ArithOp>
+OpFoldResult genBinOp(OpFoldResult a, OpFoldResult b, Location loc,
+                      OpBuilder &builder) {
+  auto aVal = getValueOrCreateConstantIndexOp(builder, loc, a);
+  auto bVal = getValueOrCreateConstantIndexOp(builder, loc, b);
+  return builder.create<ArithOp>(loc, aVal, bVal).getResult();
+}
+
+// a helper utility to perform division operation on OpFoldResult and int64_t.
+#define div(a, b)                                                              \
+  genBinOp<arith::DivSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform remainder operation on OpFoldResult and int64_t.
+#define rem(a, b)                                                              \
+  genBinOp<arith::RemSIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform multiply operation on OpFoldResult and int64_t.
+#define mul(a, b)                                                              \
+  genBinOp<arith::MulIOp>(a, builder.getIndexAttr(b), loc, builder)
+
+// a helper utility to perform addition operation on two OpFoldResult.
+#define add(a, b) genBinOp<arith::AddIOp>(a, b, loc, builder)
+
+// block the given offsets according to the block shape
+// say the original offset is [y, x], and the block shape is [By, Bx],
+// then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
+                                            ArrayRef<OpFoldResult> offsets,
+                                            ArrayRef<int64_t> blockShape) {
+
+  assert(offsets.size() == blockShape.size() &&
+         "offsets and blockShape must have the same size");
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  return blockedOffsets;
+}
+
+// Get strides as vector of integer for MemDesc. 
+SmallVector<int64_t> MemDescType::getStrides() {
+
+      SmallVector<int64_t> matrixShape(getShape().begin(),
+                      getShape().end());
+
+      ArrayAttr strideAttr = getStridesAttr();
+      SmallVector<int64_t> strides;
+      for (Attribute attr : strideAttr.getValue()) {
+      strides.push_back(cast<IntegerAttr>(attr).getInt());
+      }
+
+      llvm::dbgs() << "DEBUG: matrixShape = [";
+      for (size_t i = 0; i < matrixShape.size(); ++i) {
+      llvm::dbgs() << matrixShape[i];
+      if (i < matrixShape.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      llvm::dbgs() << "DEBUG: strides = [";
+      for (size_t i = 0; i < strides.size(); ++i) {
+      llvm::dbgs() << strides[i];
+      if (i < strides.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      SmallVector<int64_t> innerBlkShape = getBlockSize();
+      llvm::dbgs() << "DEBUG: innerBlkShape = [";
+      for (size_t i = 0; i < innerBlkShape.size(); ++i) {
+      llvm::dbgs() << innerBlkShape[i];
+      if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      if (innerBlkShape.empty())
+      return strides;
+
+      SmallVector<int, 4> perm = llvm::to_vector<4>(
+                    llvm::seq<int>(0, strides.size()));
+      llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+
+      llvm::dbgs() << "DEBUG: perm = [";
+      for (size_t i = 0; i < perm.size(); ++i) {
+      llvm::dbgs() << perm[i];
+      if (i < perm.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+
+      SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
+      
+      llvm::dbgs() << "DEBUG: innerBlkStride = [";
+      for (size_t i = 0; i < innerBlkStride.size(); ++i) {
+      llvm::dbgs() << innerBlkStride[i];
+      if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      // compute the original matrix shape using the stride info
+      // and compute the number of blocks in each dimension
+      // The shape of highest dim can't be derived from stride info,
+      // but doesn't impact the stride computation for blocked layout.
+      SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+      SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+      matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+      BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+      }
+
+      llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
+      for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
+      llvm::dbgs() << matrixShapeOrig[i];
+      if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
+      for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
+      llvm::dbgs() << BlkShapeOrig[i];
+      if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      int64_t innerBlkSize = 1;
+      for (auto s : innerBlkShape)
+      innerBlkSize *= s;
+
+      llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
+
+      SmallVector<int64_t> outerBlkStride(matrixShape.size());
+      outerBlkStride[perm[0]] = innerBlkSize;
+      for (size_t i = 0; i < perm.size() - 1; ++i) {
+      outerBlkStride[perm[i + 1]] =
+        outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
+      }
+
+      llvm::dbgs() << "DEBUG: outerBlkStride = [";
+      for (size_t i = 0; i < outerBlkStride.size(); ++i) {
+      llvm::dbgs() << outerBlkStride[i];
+      if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      // combine the inner and outer strides
+      SmallVector<int64_t> blockedStrides;
+      blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+      blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+      llvm::dbgs() << "DEBUG: blockedStrides = [";
+      for (size_t i = 0; i < blockedStrides.size(); ++i) {
+      llvm::dbgs() << blockedStrides[i];
+      if (i < blockedStrides.size() - 1) llvm::dbgs() << ", ";
+      }
+      llvm::dbgs() << "]\n";
+
+      return blockedStrides;
+    }
+
+// Calculate the linear offset using the blocked offsets and stride
+Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
+                  ArrayRef<OpFoldResult> offsets) {
+
+  SmallVector<int64_t> blockShape = getBlockSize();
+  SmallVector<int64_t> strides = getStrides();
+  
+  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
+       llvm::interleaveComma(blockShape, llvm::dbgs());
+       llvm::dbgs() << "], strides=[";
+       llvm::interleaveComma(strides, llvm::dbgs());
+       llvm::dbgs() << "]\n");
+  
+  if (!blockShape.empty()) {
+  assert(offsets.size() == blockShape.size() &&
+       "offsets and blockShape must have the same size");
+  // say the original offset is [y, x], and the block shape is [By, Bx],
+  // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+  SmallVector<OpFoldResult> blockedOffsets;
+  SmallVector<OpFoldResult> divs, rems;
+
+  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+    divs.push_back(div(offset, block));
+    rems.push_back(rem(offset, block));
+  }
+  blockedOffsets.append(divs.begin(), divs.end());
+  blockedOffsets.append(rems.begin(), rems.end());
+
+  offsets = blockedOffsets;
+  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size=" 
+              << offsets.size() << "\n");
+  }
+
+  // Start with initial value as matrix descriptor's base offset.
+  Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
+  for (size_t i = 0; i < offsets.size(); ++i) {
+  OpFoldResult mulResult = mul(offsets[i], strides[i]);
+  Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+  linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+  }
+
+  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" 
+              << linearOffset << "\n");
+
+  return linearOffset;
+}
 
 } // namespace xegpu
 } // namespace mlir
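For reference, the offset linearization implemented by getLinearOffsets above
can be sketched in a few lines of standalone C++ (illustrative names, no IR
emission); it works through the mem_desc<256x32xf16, @block=[8, 16]> example:

#include <cassert>
#include <cstdint>
#include <vector>

static int64_t linearOffset(std::vector<int64_t> offsets,
                            const std::vector<int64_t> &blockShape,
                            const std::vector<int64_t> &blockedStrides) {
  // Block the offsets: [y, x] with blocks [By, Bx] -> [y/By, x/Bx, y%By, x%Bx].
  if (!blockShape.empty()) {
    std::vector<int64_t> blocked;
    for (size_t i = 0; i < offsets.size(); ++i)
      blocked.push_back(offsets[i] / blockShape[i]);
    for (size_t i = 0; i < offsets.size(); ++i)
      blocked.push_back(offsets[i] % blockShape[i]);
    offsets = blocked;
  }
  // Dot product with the (blocked) strides.
  int64_t linear = 0;
  for (size_t i = 0; i < offsets.size(); ++i)
    linear += offsets[i] * blockedStrides[i];
  return linear;
}

int main() {
  // mem_desc<256x32xf16, @block=[8, 16]> has blocked strides [256, 128, 16, 1].
  // Element (17, 5) lives in block (2, 0) at in-block position (1, 5):
  //   2*256 + 0*128 + 1*16 + 5*1 = 533.
  assert(linearOffset({17, 5}, {8, 16}, {256, 128, 16, 1}) == 533);
  return 0;
}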
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 81b5788d0b9b4..23e487787652d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1062,9 +1062,12 @@ LogicalResult LoadMatrixOp::verify() {
 
   ArrayRef<int64_t> valueShape = resTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
-  if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
-                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
-    return emitOpError("result shape must not exceed mem_desc shape.");
+
+  if (valueShape.size() != 1) {
+    if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+      return emitOpError("result shape must not exceed mem_desc shape.");
+  }
   return success();
 }
 
@@ -1092,10 +1095,11 @@ LogicalResult StoreMatrixOp::verify() {
 
   ArrayRef<int64_t> dataShape = dataTy.getShape();
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
-  if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
-                   [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
-    return emitOpError("data shape must not exceed mem_desc shape.");
-
+  if (dataShape.size() != 1) {
+    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+                    [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+      return emitOpError("data shape must not exceed mem_desc shape.");
+  }
   return success();
 }
 
@@ -1127,7 +1131,7 @@ LogicalResult MemDescSubviewOp::verify() {
           [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
     return emitOpError("result shape must not exceed source shape.");
 
-  if (srcTy.getStrides() != resTy.getStrides())
+  if (srcTy.getStridesAttr() != resTy.getStridesAttr())
     return emitOpError("result must inherit the source strides.");
 
   return success();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index e6f22f0a9acbb..bbf313bf4fb60 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -1,10 +1,6 @@
 // RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
 
-#sg_map_a_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-#sg_map_b_f16 = #xegpu.layout<lane_layout = [1, 16], lane_data = [2, 1]>
-#sg_map_c_f32 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
-
-gpu.module @load_store_check {
+gpu.module @test_kernel {
     // CHECK-LABEL: func.func @dpas(
     // CHECK-SAME: %[[ARG0:.*]]: vector<8xf16>, %[[ARG1:.*]]: vector<16xf16>, %[[ARG2:.*]]: vector<8xf32>
     func.func @dpas(%a_loaded: vector<8xf16>, %b_loaded: vector<16xf16>, %c_loaded: vector<8xf32>) -> vector<8xf32> {
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
new file mode 100644
index 0000000000000..30d6274c9dccf
--- /dev/null
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
+
+gpu.module @test_kernel {
+  //CHECK-LABEL: load_store_matrix_1
+  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+    %tid_x = gpu.thread_id x
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+    gpu.return %1: vector<8xf32>
+  }
+
+  // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]>
+  // its memory layout tuple is ([2,2,16,16],[256,512,1,16])
+
+  //CHECK-LABEL: load_store_matrix_2
+  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<stride = [1, 16], block = [16, 16]>>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+    %tid_x = gpu.thread_id x
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+    gpu.return %1: vector<8xf32>
+  }
+
+  // e.g. for mem_desc<32x32xf16, @block=[16, 16]>
+  // its memory layout tuple is ([2,2,16,16],[512,256,16,1])
+  //CHECK-LABEL: load_store_matrix_3
+  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<block = [16, 16]>>
+    //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
+    //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+    %tid_x = gpu.thread_id x
+    %c0 = arith.constant 0 : index
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
+    gpu.return %1: vector<8xf32>
+  }
+
+}

>From 554b95edf3079fee2ac91ccd22078886244724f0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 7 Oct 2025 23:17:20 +0000
Subject: [PATCH 02/12] add attributes

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |   3 +
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    |  63 +++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 262 +++++++++---------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |   7 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  47 ++--
 mlir/test/Dialect/XeGPU/ops.mlir              |  57 +++-
 6 files changed, 253 insertions(+), 186 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 32d21bae8cd34..a0a8669baf90d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1302,6 +1302,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<I32Attr>:$vec_length,
+    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+    OptionalAttr<UnitAttr>:$subgroupBlockIO,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let results = (outs XeGPU_ValueType:$res);
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 666df293bb8be..97deca167204a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -371,7 +371,8 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                        Value baseAddr, Value offset, int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
       rewriter, loc, rewriter.getI64Type(), elemByteSize);
-  offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset);
+  offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                        offset);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
@@ -513,29 +514,36 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
 // on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
 // 32 bits will be converted to 32 bits.
 class CreateMemDescOpPattern final
-  : public OpConversionPattern<xegpu::CreateMemDescOp> {
+    : public OpConversionPattern<xegpu::CreateMemDescOp> {
 public:
   using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
   LogicalResult
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
-          ConversionPatternRewriter &rewriter) const override {
-  // DEBUG: Print operation and types
-  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
-  TypedValue<MemRefType> src = op.getSource();
-  auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
- 
-  // Create the result MemRefType with the same shape, element type, and memory space
-  auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
-  
-  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
-  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
-  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
-  Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
-  auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero,
-                        ValueRange());
-  rewriter.replaceOp(op, viewOp);
-  LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
-  return success();
+                  ConversionPatternRewriter &rewriter) const override {
+    // DEBUG: Print operation and types
+    LLVM_DEBUG(llvm::dbgs()
+               << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
+    TypedValue<MemRefType> src = op.getSource();
+    auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+
+    // Create the result MemRefType with the same shape, element type, and
+    // memory space
+    auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
+    LLVM_DEBUG(llvm::dbgs()
+               << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
+    Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
+    auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
+                                         Value(src), zero, ValueRange());
+    rewriter.replaceOp(op, viewOp);
+    LLVM_DEBUG(
+        llvm::dbgs()
+        << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
+    return success();
   }
 };
 
@@ -551,7 +559,6 @@ class MemDescSubviewOpPattern final
   }
 };
 
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -577,7 +584,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
       data = adaptor.getData();
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
 
-    int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    int64_t elemBitWidth =
+        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
     // Element type must be multiple of 8 bits.
     if (elemBitWidth % 8 != 0)
       return rewriter.notifyMatchFailure(
@@ -589,14 +597,17 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
         ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
 
     auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
-    
-    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct);
+
+    Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
+        rewriter, loc, basePtrStruct);
 
     // Convert base pointer (ptr) to i64
-    Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+    Value basePtrI64 = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
 
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
-    basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+    basePtrI64 =
+        addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
 
     // convert base pointer (i64) to LLVM pointer type
     basePtrLLVM =
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index c64699c12cf4a..3cbb39ee9b144 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -777,168 +777,176 @@ SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
   return blockedOffsets;
 }
 
-// Get strides as vector of integer for MemDesc. 
+// Get strides as vector of integer for MemDesc.
 SmallVector<int64_t> MemDescType::getStrides() {
 
-      SmallVector<int64_t> matrixShape(getShape().begin(),
-                      getShape().end());
+  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
 
-      ArrayAttr strideAttr = getStridesAttr();
-      SmallVector<int64_t> strides;
-      for (Attribute attr : strideAttr.getValue()) {
-      strides.push_back(cast<IntegerAttr>(attr).getInt());
-      }
+  ArrayAttr strideAttr = getStridesAttr();
+  SmallVector<int64_t> strides;
+  for (Attribute attr : strideAttr.getValue()) {
+    strides.push_back(cast<IntegerAttr>(attr).getInt());
+  }
 
-      llvm::dbgs() << "DEBUG: matrixShape = [";
-      for (size_t i = 0; i < matrixShape.size(); ++i) {
-      llvm::dbgs() << matrixShape[i];
-      if (i < matrixShape.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  llvm::dbgs() << "DEBUG: matrixShape = [";
+  for (size_t i = 0; i < matrixShape.size(); ++i) {
+    llvm::dbgs() << matrixShape[i];
+    if (i < matrixShape.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      llvm::dbgs() << "DEBUG: strides = [";
-      for (size_t i = 0; i < strides.size(); ++i) {
-      llvm::dbgs() << strides[i];
-      if (i < strides.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  llvm::dbgs() << "DEBUG: strides = [";
+  for (size_t i = 0; i < strides.size(); ++i) {
+    llvm::dbgs() << strides[i];
+    if (i < strides.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  SmallVector<int64_t> innerBlkShape = getBlockSize();
+  llvm::dbgs() << "DEBUG: innerBlkShape = [";
+  for (size_t i = 0; i < innerBlkShape.size(); ++i) {
+    llvm::dbgs() << innerBlkShape[i];
+    if (i < innerBlkShape.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      SmallVector<int64_t> innerBlkShape = getBlockSize();
-      llvm::dbgs() << "DEBUG: innerBlkShape = [";
-      for (size_t i = 0; i < innerBlkShape.size(); ++i) {
-      llvm::dbgs() << innerBlkShape[i];
-      if (i < innerBlkShape.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  if (innerBlkShape.empty())
+    return strides;
 
-      if (innerBlkShape.empty())
-      return strides;
+  SmallVector<int, 4> perm =
+      llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
+  llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
 
-      SmallVector<int, 4> perm = llvm::to_vector<4>(
-                    llvm::seq<int>(0, strides.size()));
-      llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
+  llvm::dbgs() << "DEBUG: perm = [";
+  for (size_t i = 0; i < perm.size(); ++i) {
+    llvm::dbgs() << perm[i];
+    if (i < perm.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      llvm::dbgs() << "DEBUG: perm = [";
-      for (size_t i = 0; i < perm.size(); ++i) {
-      llvm::dbgs() << perm[i];
-      if (i < perm.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
 
-      assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
+  SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
 
-      SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
-      
-      llvm::dbgs() << "DEBUG: innerBlkStride = [";
-      for (size_t i = 0; i < innerBlkStride.size(); ++i) {
-      llvm::dbgs() << innerBlkStride[i];
-      if (i < innerBlkStride.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
-
-      // compute the original matrix shape using the stride info
-      // and compute the number of blocks in each dimension
-      // The shape of highest dim can't be derived from stride info,
-      // but doesn't impact the stride computation for blocked layout.
-      SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
-      SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
-      for (size_t i = 0; i < perm.size() - 1; ++i) {
-      matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
-      BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
-      }
+  llvm::dbgs() << "DEBUG: innerBlkStride = [";
+  for (size_t i = 0; i < innerBlkStride.size(); ++i) {
+    llvm::dbgs() << innerBlkStride[i];
+    if (i < innerBlkStride.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  // compute the original matrix shape using the stride info
+  // and compute the number of blocks in each dimension
+  // The shape of highest dim can't be derived from stride info,
+  // but doesn't impact the stride computation for blocked layout.
+  SmallVector<int64_t> matrixShapeOrig(matrixShape.size());
+  SmallVector<int64_t> BlkShapeOrig(matrixShape.size());
+  for (size_t i = 0; i < perm.size() - 1; ++i) {
+    matrixShapeOrig[perm[i]] = strides[perm[i + 1]] / strides[perm[i]];
+    BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
+  }
 
-      llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
-      for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
-      llvm::dbgs() << matrixShapeOrig[i];
-      if (i < matrixShapeOrig.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
+  for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
+    llvm::dbgs() << matrixShapeOrig[i];
+    if (i < matrixShapeOrig.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
-      for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
-      llvm::dbgs() << BlkShapeOrig[i];
-      if (i < BlkShapeOrig.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
+  for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
+    llvm::dbgs() << BlkShapeOrig[i];
+    if (i < BlkShapeOrig.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      int64_t innerBlkSize = 1;
-      for (auto s : innerBlkShape)
-      innerBlkSize *= s;
+  int64_t innerBlkSize = 1;
+  for (auto s : innerBlkShape)
+    innerBlkSize *= s;
 
-      llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
+  llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
 
-      SmallVector<int64_t> outerBlkStride(matrixShape.size());
-      outerBlkStride[perm[0]] = innerBlkSize;
-      for (size_t i = 0; i < perm.size() - 1; ++i) {
-      outerBlkStride[perm[i + 1]] =
+  SmallVector<int64_t> outerBlkStride(matrixShape.size());
+  outerBlkStride[perm[0]] = innerBlkSize;
+  for (size_t i = 0; i < perm.size() - 1; ++i) {
+    outerBlkStride[perm[i + 1]] =
         outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
-      }
-
-      llvm::dbgs() << "DEBUG: outerBlkStride = [";
-      for (size_t i = 0; i < outerBlkStride.size(); ++i) {
-      llvm::dbgs() << outerBlkStride[i];
-      if (i < outerBlkStride.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
-
-      // combine the inner and outer strides
-      SmallVector<int64_t> blockedStrides;
-      blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
-      blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+  }
 
-      llvm::dbgs() << "DEBUG: blockedStrides = [";
-      for (size_t i = 0; i < blockedStrides.size(); ++i) {
-      llvm::dbgs() << blockedStrides[i];
-      if (i < blockedStrides.size() - 1) llvm::dbgs() << ", ";
-      }
-      llvm::dbgs() << "]\n";
+  llvm::dbgs() << "DEBUG: outerBlkStride = [";
+  for (size_t i = 0; i < outerBlkStride.size(); ++i) {
+    llvm::dbgs() << outerBlkStride[i];
+    if (i < outerBlkStride.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
+
+  // combine the inner and outer strides
+  SmallVector<int64_t> blockedStrides;
+  blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
+  blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
+
+  llvm::dbgs() << "DEBUG: blockedStrides = [";
+  for (size_t i = 0; i < blockedStrides.size(); ++i) {
+    llvm::dbgs() << blockedStrides[i];
+    if (i < blockedStrides.size() - 1)
+      llvm::dbgs() << ", ";
+  }
+  llvm::dbgs() << "]\n";
 
-      return blockedStrides;
-    }
+  return blockedStrides;
+}
 
 // Calculate the linear offset using the blocked offsets and stride
 Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
-                  ArrayRef<OpFoldResult> offsets) {
+                                    ArrayRef<OpFoldResult> offsets) {
 
   SmallVector<int64_t> blockShape = getBlockSize();
   SmallVector<int64_t> strides = getStrides();
-  
+
   LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
-       llvm::interleaveComma(blockShape, llvm::dbgs());
-       llvm::dbgs() << "], strides=[";
-       llvm::interleaveComma(strides, llvm::dbgs());
-       llvm::dbgs() << "]\n");
-  
-  if (!blockShape.empty()) {
-  assert(offsets.size() == blockShape.size() &&
-       "offsets and blockShape must have the same size");
-  // say the original offset is [y, x], and the block shape is [By, Bx],
-  // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
-  SmallVector<OpFoldResult> blockedOffsets;
-  SmallVector<OpFoldResult> divs, rems;
+             llvm::interleaveComma(blockShape, llvm::dbgs());
+             llvm::dbgs() << "], strides=[";
+             llvm::interleaveComma(strides, llvm::dbgs());
+             llvm::dbgs() << "]\n");
 
-  for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
-    divs.push_back(div(offset, block));
-    rems.push_back(rem(offset, block));
-  }
-  blockedOffsets.append(divs.begin(), divs.end());
-  blockedOffsets.append(rems.begin(), rems.end());
+  if (!blockShape.empty()) {
+    assert(offsets.size() == blockShape.size() &&
+           "offsets and blockShape must have the same size");
+    // say the original offset is [y, x], and the block shape is [By, Bx],
+    // then the blocked offset is [y/By, x/Bx, y%By, x%Bx]
+    SmallVector<OpFoldResult> blockedOffsets;
+    SmallVector<OpFoldResult> divs, rems;
+
+    for (auto [offset, block] : llvm::zip(offsets, blockShape)) {
+      divs.push_back(div(offset, block));
+      rems.push_back(rem(offset, block));
+    }
+    blockedOffsets.append(divs.begin(), divs.end());
+    blockedOffsets.append(rems.begin(), rems.end());
 
-  offsets = blockedOffsets;
-  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size=" 
-              << offsets.size() << "\n");
+    offsets = blockedOffsets;
+    LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
+                            << offsets.size() << "\n");
   }
 
   // Start with initial value as matrix descriptor's base offset.
   Value linearOffset = arith::ConstantIndexOp::create(builder, loc, 0);
   for (size_t i = 0; i < offsets.size(); ++i) {
-  OpFoldResult mulResult = mul(offsets[i], strides[i]);
-  Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
-  linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
+    OpFoldResult mulResult = mul(offsets[i], strides[i]);
+    Value mulVal = getValueOrCreateConstantIndexOp(builder, loc, mulResult);
+    linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
   }
 
-  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset=" 
-              << linearOffset << "\n");
+  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
+                          << linearOffset << "\n");
 
   return linearOffset;
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 23e487787652d..c40d5a42fb6e5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1049,8 +1049,11 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
   llvm::SmallVector<int64_t> staticOffsets;
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+  // Call the generated builder with all parameters (including optional ones as
+  // nullptr/empty)
   build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
-        layout);
+        /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
+        /*subgroupBlockIO=*/nullptr, layout);
 }
 
 LogicalResult LoadMatrixOp::verify() {
@@ -1097,7 +1100,7 @@ LogicalResult StoreMatrixOp::verify() {
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
   if (dataShape.size() != 1) {
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
-                    [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
       return emitOpError("data shape must not exceed mem_desc shape.");
   }
   return success();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 30d6274c9dccf..7b87f32b876fe 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,40 +1,41 @@
 // RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
 
 gpu.module @test_kernel {
+
+  // e.g. for mem_desc<32x32xf32> with default @strides=[32, 1]
+  // its memory layout tuple is (blocked shape = [1,1,32,32], strides = [1024,1024,32,1])
   //CHECK-LABEL: load_store_matrix_1
-  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
+  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32>
     %tid_x = gpu.thread_id x
     %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
-    gpu.return %1: vector<8xf32>
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32>
+    gpu.return %1: vector<1xf32>
   }
 
-  // e.g. for mem_desc<32x32xf16, @block=[16, 16], @strides=[1, 16]>
-  // its memory layout tuple is ([2,2,16,16],[256,512,1,16])
-
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+  // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_2
-  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
-    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<stride = [1, 16], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf32>
+  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
     %tid_x = gpu.thread_id x
-    %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
-    gpu.return %1: vector<8xf32>
+    %c13 = arith.constant 13 : index
+    %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<1xf16>
+    gpu.return %1: vector<1xf16>
   }
 
-  // e.g. for mem_desc<32x32xf16, @block=[16, 16]>
-  // its memory layout tuple is ([2,2,16,16],[512,256,16,1])
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+  // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_3
-  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<8xf32> {
-    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK-COUNT-8: xegpu.load_matrix {{.*}} : !xegpu.mem_desc<32x32xf32>, index, index -> vector<8x16xf32>
-    //CHECK-COUNT-8: vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<32x32xf32>
+  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
     %tid_x = gpu.thread_id x
-    %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<8xf32>
-    gpu.return %1: vector<8xf32>
+    %c17 = arith.constant 17 : index
+    %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<1xf16>
+    gpu.return %1: vector<1xf16>
   }
 
-}
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index bb379024a34d7..47aa05763ee99 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -825,35 +825,76 @@ gpu.func @create_mem_desc_with_stride() {
   gpu.return
 }
 
-// CHECK: gpu.func @load_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @load_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>) {
+// CHECK: gpu.func @load_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
+gpu.func @load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
   %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<8x16xf16>
   gpu.return
 }
 
-// CHECK: gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @load_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+// CHECK: gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @load_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
   // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
   %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8x16xf16>
   gpu.return
 }
 
+// CHECK: gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>)
+gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 16] : !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+  %data = xegpu.load_matrix %arg0[8, 16]: !xegpu.mem_desc<16x64xf16> -> vector<1xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+  %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+  gpu.return
+}
 
-// CHECK: gpu.func @store_mem_desc([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
+gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> 
+  %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @store_matrix([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
   // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
   xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
   gpu.return
 }
 
-// CHECK: gpu.func @store_mem_desc_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
-gpu.func @store_mem_desc_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
+// CHECK: gpu.func @store_matrix_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<16x16xf16>) {
   // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][0, 8] : vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
   xegpu.store_matrix %arg1, %arg0[0, 8]: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
   gpu.return
 }
 
+// CHECK: gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) { 
+gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] : vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+  xegpu.store_matrix %arg1, %arg0[8, 16]: vector<1xf16>, !xegpu.mem_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+  xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> 
+  xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  gpu.return
+}
+
 // CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
 gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
   //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>

>From 446b951f2ed0bffd8be64955b7c4e5a94d5e2eb7 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 8 Oct 2025 22:49:43 +0000
Subject: [PATCH 03/12] add tests and refactoring

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 19 +++-
 .../lib/Conversion/XeGPUToXeVM/CMakeLists.txt |  1 +
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 80 ++++++++++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 12 ++-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 92 +++++++++++++------
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  4 +-
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  2 +-
 mlir/test/Conversion/XeGPUToXeVM/dpas.mlir    |  2 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 54 ++++++++---
 mlir/test/Dialect/XeGPU/ops.mlir              | 22 ++---
 10 files changed, 211 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a0a8669baf90d..044a8ef22d891 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1304,10 +1304,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     DenseI64ArrayAttr: $const_offsets,
     OptionalAttr<I32Attr>:$vec_length,
     OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
-    OptionalAttr<UnitAttr>:$subgroupBlockIO,
+    OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
-  let results = (outs XeGPU_ValueType:$res);
+  let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$res);  
   let assemblyFormat = [{
     $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
     prop-dict attr-dict `` `:` type(operands) `->` type(results)
@@ -1338,7 +1338,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getRes().getType().getShape();
+      auto resTy = getRes().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(resTy))
+        return vecTy.getShape();
+      return {};
     }
   }];
 
@@ -1348,10 +1351,13 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
 def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
                               AllElementTypesMatch<["mem_desc", "data"]>]> {
   let arguments = (ins
-    XeGPU_ValueType:$data,
+    AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$data,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<I32Attr>:$vec_length,
+    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
+    OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
   let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
@@ -1379,7 +1385,10 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     }
 
     ArrayRef<int64_t> getDataShape() {
-      return getData().getType().getShape();
+      auto dataTy = getData().getType();
+      if (auto vecTy = llvm::dyn_cast<VectorType>(dataTy))
+        return vecTy.getShape();
+      return {};
     }
 
   }];
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
index 84b25809f1ed0..dd9edc43a1657 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -21,6 +21,7 @@ add_mlir_conversion_library(MLIRXeGPUToXeVM
   MLIRIndexDialect
   MLIRSCFDialect
   MLIRXeGPUDialect
+  MLIRXeGPUUtils
   MLIRPass
   MLIRTransforms
   MLIRSCFTransforms
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 97deca167204a..f4f0a46c54089 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/Patterns.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/Support/FormatVariadic.h"
@@ -371,8 +372,6 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                        Value baseAddr, Value offset, int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
       rewriter, loc, rewriter.getI64Type(), elemByteSize);
-  offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
-                                        offset);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
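
A minimal standalone sketch (illustrative, not from the patch) of the byte-address computation that addOffset() emits as arith ops, written as plain C++ under the assumption that the caller has already produced the linear element offset as an i64 (the index cast now happens at the call site):

  #include <cstdint>

  // base + linear element offset * element byte size, mirroring the
  // arith.muli / arith.addi pair created by addOffset().
  uint64_t computeByteAddress(uint64_t baseAddr, uint64_t linearElemOffset,
                              uint64_t elemByteSize) {
    uint64_t byteOffset = linearElemOffset * elemByteSize; // arith.muli
    return baseAddr + byteOffset;                          // arith.addi
  }
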
@@ -583,6 +582,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     else
       data = adaptor.getData();
     VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    if (!valOrResVecTy)
+      valOrResVecTy = VectorType::get(1, data.getType());
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -606,6 +607,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
         rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
 
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
+    linearOffset = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI64Type(), linearOffset);
     basePtrI64 =
         addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
 
@@ -613,15 +616,72 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
 
-    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-
-      Value loadOp =
-          LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
-      rewriter.replaceOp(op, loadOp);
+    // If valOrResVecTy has exactly one element, lower to a scalar load/store;
+    // LLVM load/store does not support single-element vectors, so this case
+    // is handled separately.
+    if (valOrResVecTy.getNumElements() == 1) {
+      Type scalarTy = valOrResVecTy.getElementType();
+      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+        Value loadOp =
+            LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+        rewriter.replaceOp(op, loadOp);
+      } else {
+        auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+                                             basePtrLLVM);
+        rewriter.eraseOp(op);
+      }
+      return success();
     } else {
-      auto storeOp =
-          LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
-      rewriter.eraseOp(op);
+      // If the 'subgroup_block_io' attribute is present, lower to
+      // xevm.blockload / xevm.blockstore.
+      auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
+      bool subgroup_block_io =
+          subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
+      if (subgroup_block_io) {
+        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+          Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
+                                                   basePtrLLVM);
+          rewriter.replaceOp(op, loadOp);
+        } else {
+          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
+                                     adaptor.getData(), nullptr);
+          rewriter.eraseOp(op);
+        }
+      } else {
+        // For a 1D result, the memory descriptor must be column major when
+        // vec_direction is COL, and row major when vec_direction is ROW.
+        auto chipOpt = xegpu::getChipStr(op);
+        if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+          // the lowering only works for pvc and bmg
+          return rewriter.notifyMatchFailure(
+              op, "The lowering is specific to pvc or bmg.");
+        }
+        xegpu::MatrixAccessDirectionAttr vecDirection =
+            op.getVecDirectionAttr();
+        if (vecDirection &&
+            vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
+            !mdescTy.isColMajor())
+          return rewriter.notifyMatchFailure(
+              op, "mem_desc should be column major when "
+                  "vec_direction is COLUMN for 1D result.");
+        if (vecDirection &&
+            vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
+            mdescTy.isColMajor())
+          return rewriter.notifyMatchFailure(
+              op, "mem_desc should be row major when "
+                  "vec_direction is ROW for 1D result.");
+
+        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+          Value loadOp =
+              LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+          rewriter.replaceOp(op, loadOp);
+        } else {
+          auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
+                                               basePtrLLVM);
+          rewriter.eraseOp(op);
+        }
+      }
     }
     return success();
   }
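
The pattern above thus picks one of three lowerings: a scalar llvm.load/llvm.store for single-element payloads, xevm.blockload/xevm.blockstore when subgroup_block_io is present, and a plain vector llvm.load/llvm.store otherwise (restricted to pvc/bmg and checked against the mem_desc major order). A minimal C++ sketch of that dispatch; the enum and function names below are illustrative only:

  #include <cstdint>

  enum class MatrixLoweringKind {
    Scalar,      // llvm.load / llvm.store of the element type
    BlockIO,     // xevm.blockload / xevm.blockstore
    PlainVector  // llvm.load / llvm.store of the vector type (pvc/bmg only)
  };

  MatrixLoweringKind pickLoweringKind(int64_t numElements, bool subgroupBlockIO) {
    if (numElements == 1)
      return MatrixLoweringKind::Scalar;
    if (subgroupBlockIO)
      return MatrixLoweringKind::BlockIO;
    return MatrixLoweringKind::PlainVector;
  }
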
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3cbb39ee9b144..26f2f691ab860 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -813,9 +813,8 @@ SmallVector<int64_t> MemDescType::getStrides() {
   }
   llvm::dbgs() << "]\n";
 
-  if (innerBlkShape.empty())
-    return strides;
-
+  // Get perm ordering dims from fastest-changing to slowest-changing:
+  // perm[i] = the dim with the i-th smallest stride.
   SmallVector<int, 4> perm =
       llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
   llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
@@ -908,6 +907,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
 Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                     ArrayRef<OpFoldResult> offsets) {
 
+  SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
   SmallVector<int64_t> blockShape = getBlockSize();
   SmallVector<int64_t> strides = getStrides();
 
@@ -917,7 +917,11 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
              llvm::interleaveComma(strides, llvm::dbgs());
              llvm::dbgs() << "]\n");
 
-  if (!blockShape.empty()) {
+  // A block shape equal to the matrix shape means no blocking.
+  if (llvm::equal(blockShape, matrixShape)) {
+    // remove the outer dims from strides
+    strides.erase(strides.begin(), strides.begin() + matrixShape.size());
+  } else {
     assert(offsets.size() == blockShape.size() &&
            "offsets and blockShape must have the same size");
     // say the original offset is [y, x], and the block shape is [By, Bx],
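
A minimal C++ sketch of the decomposition described in the comment above, assuming a 2D mem_desc with offset [y, x], block shape [By, Bx], and the four blocked strides returned by getStrides(); this mirrors the divsi/remsi and muli/addi sequences checked in loadstore_matrix.mlir:

  #include <array>
  #include <cstddef>
  #include <cstdint>

  // [y, x] -> [y / By, x / Bx, y % By, x % Bx]
  std::array<int64_t, 4> blockOffsets(int64_t y, int64_t x, int64_t By,
                                      int64_t Bx) {
    return {y / By, x / Bx, y % By, x % Bx};
  }

  // Linear element offset = dot(blocked offsets, blocked strides).
  int64_t linearOffset(const std::array<int64_t, 4> &offs,
                       const std::array<int64_t, 4> &blockedStrides) {
    int64_t acc = 0;
    for (std::size_t i = 0; i < offs.size(); ++i)
      acc += offs[i] * blockedStrides[i];
    return acc;
  }
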
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index c40d5a42fb6e5..0bc7b3f06ec53 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -173,6 +173,51 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
   return success();
 }
 
+static LogicalResult isValidMatrixOpParams(
+    VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
+    MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
+    function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!dataTy) {
+    if (subgroup_block_io || vecDirection || vecLength)
+      return emitError() << "vec_length, vec_direction and subgroup_block_io "
+                            "are only allowed when result is a 1D VectorType.";
+    return success();
+  }
+
+  if (mdescTy.getRank() != 2)
+    return emitError() << "mem_desc must be 2D.";
+
+  ArrayRef<int64_t> dataShape = dataTy.getShape();
+  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+
+  if (dataShape.size() == 2) {
+    if (subgroup_block_io || vecDirection || vecLength)
+      return emitError() << "vec_length, vec_direction and subgroup_block_io "
+                            "are only allowed when result is a 1D VectorType.";
+    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+      return emitError() << "data shape must not exceed mem_desc shape.";
+  } else if (dataShape.size() == 1) {
+
+    SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
+    // If the subgroup_block_io attribute is set, mdescTy must have a block
+    // attribute.
+    if (subgroup_block_io && !blockSize.size())
+      return emitError() << "mem_desc must have block attribute when "
+                            "subgroup_block_io is set.";
+    // If the subgroup_block_io attribute is set, the mem_desc must be row
+    // major.
+    if (subgroup_block_io && mdescTy.isColMajor())
+      return emitError() << "mem_desc should be row major when "
+                            "subgroup_block_io is set.";
+  } else if (dataShape.size() == 0) {
+    return emitError() << "result shape must not be empty.";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -1053,25 +1098,20 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
   // nullptr/empty)
   build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
         /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
-        /*subgroupBlockIO=*/nullptr, layout);
+        /*subgroup_block_io=*/nullptr, layout);
 }
 
 LogicalResult LoadMatrixOp::verify() {
-  VectorType resTy = getRes().getType();
-  MemDescType mdescTy = getMemDesc().getType();
-
-  if (mdescTy.getRank() != 2)
-    return emitOpError("mem_desc must be 2D.");
 
-  ArrayRef<int64_t> valueShape = resTy.getShape();
-  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+  auto resTy = dyn_cast<VectorType>(getRes().getType());
+  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+  MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
+  IntegerAttr vecLength = getVecLengthAttr();
+  MemDescType mdescTy = getMemDesc().getType();
 
-  if (valueShape.size() != 1) {
-    if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
-                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
-      return emitOpError("result shape must not exceed mem_desc shape.");
-  }
-  return success();
+  return isValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+                               vecDirection, vecLength,
+                               [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1086,24 +1126,20 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
   build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
-        layout);
+        /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
+        /*subgroup_block_io=*/nullptr, layout);
 }
 
 LogicalResult StoreMatrixOp::verify() {
-  VectorType dataTy = getData().getType();
-  MemDescType mdescTy = getMemDesc().getType();
 
-  if (mdescTy.getRank() != 2)
-    return emitOpError("mem_desc must be 2D.");
-
-  ArrayRef<int64_t> dataShape = dataTy.getShape();
-  ArrayRef<int64_t> mdescShape = mdescTy.getShape();
-  if (dataShape.size() != 1) {
-    if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
-                     [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
-      return emitOpError("data shape must not exceed mem_desc shape.");
-  }
-  return success();
+  auto dataTy = dyn_cast<VectorType>(getData().getType());
+  UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
+  MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
+  IntegerAttr vecLength = getVecLengthAttr();
+  MemDescType mdescTy = getMemDesc().getType();
+  return isValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+                               vecDirection, vecLength,
+                               [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
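
A compact restatement of what the shared verifier accepts, as an illustrative plain C++ sketch (the struct and field names are assumptions for exposition; the 2D shape-bound check against the mem_desc shape is left out):

  struct MatrixAccess {
    int dataRank;            // 0 = scalar, 1 = 1D vector, 2 = 2D vector
    bool hasVecAttrs;        // vec_length / vec_direction present
    bool hasSubgroupBlockIO; // subgroup_block_io present
    bool memDescHasBlock;    // mem_desc carries a block attribute
    bool memDescIsColMajor;  // stride of 1 on the first (row) dim
  };

  bool isAcceptedByVerifier(const MatrixAccess &a) {
    // vec_length, vec_direction and subgroup_block_io only apply to 1D vectors.
    if (a.dataRank != 1 && (a.hasVecAttrs || a.hasSubgroupBlockIO))
      return false;
    // subgroup_block_io additionally needs a blocked, row-major mem_desc.
    if (a.hasSubgroupBlockIO && (!a.memDescHasBlock || a.memDescIsColMajor))
      return false;
    return true;
  }
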
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a178d0fe4b0b0..6d17b27849a43 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -941,7 +941,7 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
   LogicalResult matchAndRewrite(xegpu::LoadMatrixOp op,
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
-    VectorType valueTy = op.getType();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
       return failure();
@@ -984,7 +984,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
       return failure();
 
     Location loc = op.getLoc();
-    VectorType valueTy = op.getData().getType();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
     ArrayRef<int64_t> shape = valueTy.getShape();
     auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a9296b184..d57289a6b21e9 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -867,7 +867,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
       return failure();
 
     ArrayRef<int64_t> wgShape = op.getDataShape();
-    VectorType valueTy = op.getRes().getType();
+    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
     Type elemTy = valueTy.getElementType();
 
     xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
index bbf313bf4fb60..a9ab0be00722c 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/dpas.mlir
@@ -7,7 +7,7 @@ gpu.module @test_kernel {
         // Loads are checked in a separate test.
         // CHECK: %[[D:.*]] = xevm.mma %[[ARG0]], %[[ARG1]], %[[ARG2]] {shape = <m = 8, n = 16, k = 16>, types = <d = f32, a = f16, b = f16, c = f32>}
         // CHECK-SAME:    : (vector<8xf16>, vector<16xf16>, vector<8xf32>) -> vector<8xf32>
-        %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded {a_layout = #sg_map_a_f16, b_layout = #sg_map_b_f16, c_layout = #sg_map_c_f32}
+        %d = xegpu.dpas %a_loaded, %b_loaded, %c_loaded
             : vector<8xf16>, vector<16xf16>, vector<8xf32> -> vector<8xf32>
         return %d : vector<8xf32>
     }
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 7b87f32b876fe..372f477219817 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,41 +1,65 @@
 // RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
 
-gpu.module @test_kernel {
+gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
   // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
   // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
   //CHECK-LABEL: load_store_matrix_1
-  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> vector<1xf32> {
+  gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf32>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
     %tid_x = gpu.thread_id x
     %c0 = arith.constant 0 : index
-    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> vector<1xf32>
-    gpu.return %1: vector<1xf32>
+    %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+    gpu.return %1: f32
   }
 
-  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_2
-  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+  gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
-    %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<1xf16>
-    gpu.return %1: vector<1xf16>
+    %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+    gpu.return %1: f16
   }
 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_3
-  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> vector<1xf16> {
+  gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    %tid_x = gpu.thread_id x
+    %c19 = arith.constant 19: index
+    %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+    gpu.return %1: f16
+  }
+  
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
+  // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
+  //CHECK-LABEL: load_store_matrix_4
+  gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
     %tid_x = gpu.thread_id x
-    %c17 = arith.constant 17 : index
-    %1 = xegpu.load_matrix %0[%c17, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<1xf16>
-    gpu.return %1: vector<1xf16>
+    %c16 = arith.constant 16 : index
+    %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+    gpu.return %1: vector<8xf16>
+  }
+
+  // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
+  // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
+  //CHECK-LABEL: load_store_matrix_5
+  gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+    %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    %c16 = arith.constant 16 : index
+    %c48 = arith.constant 48 : index
+    %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+    gpu.return %1: vector<8xf16>
   }
 
 }
\ No newline at end of file
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 47aa05763ee99..eb5d653be8b9c 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -846,17 +846,17 @@ gpu.func @simt_load_matrix(%arg0: !xegpu.mem_desc<16x64xf16>) {
   gpu.return
 }
 
-// CHECK: gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
-gpu.func @simt_load_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
-  // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroupBlockIO}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
-  %data = xegpu.load_matrix %arg0[8, 16] {subgroupBlockIO}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+// CHECK: gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>)
+gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
+  %data = xegpu.load_matrix %arg0[8, 16] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<8xf16>
   gpu.return
 }
 
 // CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
 gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
   // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> 
-  %data = xegpu.load_matrix %arg0[8, 8]{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+  %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
   gpu.return
 }
 
@@ -881,17 +881,17 @@ gpu.func @simt_store_matrix(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<1xf
   gpu.return
 }
 
-// CHECK: gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
-gpu.func @simt_store_matrix_subgroupblockIO(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
-  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-  xegpu.store_matrix %arg1, %arg0[8, 16] {subgroupBlockIO}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+// CHECK: gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>)
+gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>, %arg1: vector<8xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+  xegpu.store_matrix %arg1, %arg0[8, 16] <{subgroup_block_io}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<block = [16, 16]>>
   gpu.return
 }
 
 // CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
 gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
-  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] {vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> 
-  xegpu.store_matrix %arg1, %arg0[8, 8] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> 
+  xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
   gpu.return
 }
 

>From 9f9744cecbd30fea7b63c47768b323879222d105 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 9 Oct 2025 02:00:49 +0000
Subject: [PATCH 04/12] bug fixes

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 43 +++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 92 +------------------
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  2 +-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  2 +-
 4 files changed, 30 insertions(+), 109 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index f4f0a46c54089..67e8246e5536a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -33,8 +33,6 @@
 
 #include <numeric>
 
-#define DEBUG_TYPE "xegpu-to-xevm"
-
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -519,9 +517,6 @@ class CreateMemDescOpPattern final
   LogicalResult
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    // DEBUG: Print operation and types
-    LLVM_DEBUG(llvm::dbgs()
-               << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
     TypedValue<MemRefType> src = op.getSource();
     auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
 
@@ -529,19 +524,10 @@ class CreateMemDescOpPattern final
     // memory space
     auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
-    LLVM_DEBUG(llvm::dbgs()
-               << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
-    LLVM_DEBUG(llvm::dbgs()
-               << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
     Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
     auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
                                          Value(src), zero, ValueRange());
     rewriter.replaceOp(op, viewOp);
-    LLVM_DEBUG(
-        llvm::dbgs()
-        << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
     return success();
   }
 };
@@ -635,16 +621,33 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
       // if the attribute 'subgroup_block_io' is set to true, it lowers to
       // xevm.blockload
       auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
-      bool subgroup_block_io =
-          subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
+      bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
+
+      // BlockLoadOp/BlockStoreOp only support integer element types, so derive
+      // an integer type with matching bit width and bitcast when needed.
+      Type elemTy = valOrResVecTy.getElementType();
+      int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
+      Type intElemTy = rewriter.getIntegerType(bitWidth);
+      VectorType intVecTy =
+          VectorType::get(valOrResVecTy.getShape(), intElemTy);
+
       if (subgroup_block_io) {
         if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-          Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
-                                                   basePtrLLVM);
+          Value loadOp =
+              xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+          if (intVecTy != valOrResVecTy) {
+            loadOp =
+                vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
+          }
           rewriter.replaceOp(op, loadOp);
         } else {
-          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
-                                     adaptor.getData(), nullptr);
+          Value dataToStore = adaptor.getData();
+          if (valOrResVecTy != intVecTy) {
+            dataToStore =
+                vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
+          }
+          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+                                     nullptr);
           rewriter.eraseOp(op);
         }
       } else {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 26f2f691ab860..cccc8fab4adbc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -37,8 +37,6 @@ void XeGPUDialect::initialize() {
       >();
 }
 
-#define DEBUG_TYPE "xegpu"
-
 /// Generates instructions to compute offsets for a subgroup identified by
 /// its multidimensional indices (sgId), using the specified subgroup layout
 /// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -788,30 +786,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
     strides.push_back(cast<IntegerAttr>(attr).getInt());
   }
 
-  llvm::dbgs() << "DEBUG: matrixShape = [";
-  for (size_t i = 0; i < matrixShape.size(); ++i) {
-    llvm::dbgs() << matrixShape[i];
-    if (i < matrixShape.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
-  llvm::dbgs() << "DEBUG: strides = [";
-  for (size_t i = 0; i < strides.size(); ++i) {
-    llvm::dbgs() << strides[i];
-    if (i < strides.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
   SmallVector<int64_t> innerBlkShape = getBlockSize();
-  llvm::dbgs() << "DEBUG: innerBlkShape = [";
-  for (size_t i = 0; i < innerBlkShape.size(); ++i) {
-    llvm::dbgs() << innerBlkShape[i];
-    if (i < innerBlkShape.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
 
   // get perm from FCD to LCD
   // perm[i] = the dim with i-th smallest stride
@@ -819,25 +794,13 @@ SmallVector<int64_t> MemDescType::getStrides() {
       llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
   llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
 
-  llvm::dbgs() << "DEBUG: perm = [";
-  for (size_t i = 0; i < perm.size(); ++i) {
-    llvm::dbgs() << perm[i];
-    if (i < perm.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
   assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
 
-  SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
-
-  llvm::dbgs() << "DEBUG: innerBlkStride = [";
-  for (size_t i = 0; i < innerBlkStride.size(); ++i) {
-    llvm::dbgs() << innerBlkStride[i];
-    if (i < innerBlkStride.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
+  SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
+  innerBlkStride[perm[0]] = 1;
+  for (size_t i = 1; i < perm.size(); ++i)
+    innerBlkStride[perm[i]] =
+        innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
 
   // compute the original matrix shape using the stride info
   // and compute the number of blocks in each dimension
@@ -850,28 +813,10 @@ SmallVector<int64_t> MemDescType::getStrides() {
     BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
   }
 
-  llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
-  for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
-    llvm::dbgs() << matrixShapeOrig[i];
-    if (i < matrixShapeOrig.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
-  llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
-  for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
-    llvm::dbgs() << BlkShapeOrig[i];
-    if (i < BlkShapeOrig.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
   int64_t innerBlkSize = 1;
   for (auto s : innerBlkShape)
     innerBlkSize *= s;
 
-  llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
-
   SmallVector<int64_t> outerBlkStride(matrixShape.size());
   outerBlkStride[perm[0]] = innerBlkSize;
   for (size_t i = 0; i < perm.size() - 1; ++i) {
@@ -879,27 +824,11 @@ SmallVector<int64_t> MemDescType::getStrides() {
         outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
   }
 
-  llvm::dbgs() << "DEBUG: outerBlkStride = [";
-  for (size_t i = 0; i < outerBlkStride.size(); ++i) {
-    llvm::dbgs() << outerBlkStride[i];
-    if (i < outerBlkStride.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
   // combine the inner and outer strides
   SmallVector<int64_t> blockedStrides;
   blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
   blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
 
-  llvm::dbgs() << "DEBUG: blockedStrides = [";
-  for (size_t i = 0; i < blockedStrides.size(); ++i) {
-    llvm::dbgs() << blockedStrides[i];
-    if (i < blockedStrides.size() - 1)
-      llvm::dbgs() << ", ";
-  }
-  llvm::dbgs() << "]\n";
-
   return blockedStrides;
 }
 
@@ -911,12 +840,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
   SmallVector<int64_t> blockShape = getBlockSize();
   SmallVector<int64_t> strides = getStrides();
 
-  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
-             llvm::interleaveComma(blockShape, llvm::dbgs());
-             llvm::dbgs() << "], strides=[";
-             llvm::interleaveComma(strides, llvm::dbgs());
-             llvm::dbgs() << "]\n");
-
   // blockshape equal to matrixshape means no blocking
   if (llvm::equal(blockShape, matrixShape)) {
     // remove the outer dims from strides
@@ -937,8 +860,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
     blockedOffsets.append(rems.begin(), rems.end());
 
     offsets = blockedOffsets;
-    LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
-                            << offsets.size() << "\n");
   }
 
   // Start with initial value as matrix descriptor's base offset.
@@ -949,9 +870,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
     linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
   }
 
-  LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
-                          << linearOffset << "\n");
-
   return linearOffset;
 }
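
The blocked-stride computation above can be cross-checked against the layout comments in loadstore_matrix.mlir. An illustrative standalone C++ sketch, assuming the number of blocks per dim is matrixShape[i] / blockShape[i]; for mem_desc<32x64xf16, block = [16, 16], stride = [1, 32]> it produces [256, 512, 1, 16], and [1024, 256, 16, 1] for the default row-major case:

  #include <algorithm>
  #include <cstddef>
  #include <cstdint>
  #include <numeric>
  #include <vector>

  std::vector<int64_t> blockedStrides(const std::vector<int64_t> &matrixShape,
                                      const std::vector<int64_t> &strides,
                                      const std::vector<int64_t> &blockShape) {
    // perm[i] = dim with the i-th smallest stride (fastest-changing first).
    std::vector<int> perm(strides.size());
    std::iota(perm.begin(), perm.end(), 0);
    std::sort(perm.begin(), perm.end(),
              [&](int a, int b) { return strides[a] < strides[b]; });

    // Strides inside one block: elements are contiguous along the perm order.
    std::vector<int64_t> inner(blockShape.size());
    inner[perm[0]] = 1;
    for (std::size_t i = 1; i < perm.size(); ++i)
      inner[perm[i]] = inner[perm[i - 1]] * blockShape[perm[i - 1]];

    // Strides between whole blocks: one block of elements apart, scaled by the
    // number of blocks along each faster-changing dim.
    int64_t blockElems = 1;
    for (int64_t s : blockShape)
      blockElems *= s;
    std::vector<int64_t> outer(matrixShape.size());
    outer[perm[0]] = blockElems;
    for (std::size_t i = 0; i + 1 < perm.size(); ++i)
      outer[perm[i + 1]] =
          outer[perm[i]] * (matrixShape[perm[i]] / blockShape[perm[i]]);

    // Returned layout is [outer strides..., inner strides...].
    std::vector<int64_t> result = outer;
    result.insert(result.end(), inner.begin(), inner.end());
    return result;
  }
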
 
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 372f477219817..3713635a1cc71 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -55,7 +55,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   //CHECK-LABEL: load_store_matrix_5
   gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16> 
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 228ef69d9a478..bef45438c944e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
 
 // -----
 func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error@+1 {{result shape must not exceed mem_desc shape}}
+  // expected-error@+1 {{data shape must not exceed mem_desc shape}}
   %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
   return
 }

>From bbd43d089096c8c66507c59c7df0c42d2806bcc0 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Oct 2025 00:20:47 +0000
Subject: [PATCH 05/12] polish tests

---
 .../XeGPUToXeVM/loadstore_matrix.mlir         | 154 +++++++++++++++++-
 mlir/test/Dialect/XeGPU/invalid.mlir          |  29 ++++
 2 files changed, 174 insertions(+), 9 deletions(-)

diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 3713635a1cc71..6302758195e51 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -1,64 +1,200 @@
-// RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm --cse --canonicalize %s | FileCheck %s
+// RUN: mlir-opt  -split-input-file -convert-xegpu-to-xevm -cse %s | FileCheck %s
 
 gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
 
-  // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
+ // e.g. for mem_desc<32x32xf16, @strides=[1, 16]>
   // its memory layout tuple is (blocked shape = [1,1,32,32],strides=[1024,1024,32,1])
   //CHECK-LABEL: load_store_matrix_1
   gpu.func @load_store_matrix_1(%arg0: memref<4096xi8, 3>) -> f32 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+
+    //CHECK: %[[TID:.*]] = gpu.thread_id x
+    //CHECK: %[[C1:.*]] = arith.constant 1 : index
+    //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
+    //CHECK: %[[C4:.*]] = arith.constant 4 : i64
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
+
     %tid_x = gpu.thread_id x
     %c0 = arith.constant 0 : index
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
+
+    //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
+
+     xegpu.store_matrix %1, %0[%c0, %tid_x]: f32, !xegpu.mem_desc<32x32xf32>, index, index
+
     gpu.return %1: f32
   }
 
-  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
+// e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 32]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_2
   gpu.func @load_store_matrix_2(%arg0: memref<4096xi8, 3>) -> f16 {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK: %[[c13:.*]] = arith.constant 13 : index
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c13]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c13]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
+ 
+
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
     %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
+
+    //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+   
+    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index 
     gpu.return %1: f16
   }
 
+
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_3
   gpu.func @load_store_matrix_3(%arg0: memref<4096xi8, 3>) -> f16 {
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f16
+    
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+    //CHECK: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
+    
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c1024]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c256]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c16]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c1]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
+    
+    //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
+    xegpu.store_matrix %1, %0[%c19, %tid_x]:  f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+    
+    //CHECK: gpu.return %[[loaded]] : f16
     gpu.return %1: f16
   }
-  
-  // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
+
+   // e.g. for mem_desc<32x64xf16, @block=[16, 16], @strides=[1, 16]>
   // its memory layout tuple is ([2,4,16,16],[256,512,1,16])
   //CHECK-LABEL: load_store_matrix_4
   gpu.func @load_store_matrix_4(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>
-    //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[tid_x:.*]] = gpu.thread_id x
+
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offsety_0:.*]] = arith.divsi %[[tid_x]], %[[c16]] : index
+    //CHECK: %[[offsety_1:.*]] = arith.remsi %[[tid_x]], %[[c16]] : index
+
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offsetx_0]], %[[c256]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c512:.*]] = arith.constant 512 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offsety_0]], %[[c512]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offsetx_1]], %[[c1]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offsety_1]], %[[c16]] : index
+    //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+
+    //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
+     
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
     %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+
+    //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
+    xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
     gpu.return %1: vector<8xf16>
   }
 
+ 
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_5
   gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
+    //CHECK: %[[c0:.*]] = arith.constant 0 : index
+    //CHECK: %[[view:.*]] = memref.view %arg0[%[[c0]]][] : memref<4096xi8, 3> to memref<2048xf16, 3>
+ 
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-    //CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16> 
+ 
+    //CHECK: %[[c16:.*]] = arith.constant 16 : index
+    //CHECK: %[[c48:.*]] = arith.constant 48 : index
+  
     %c16 = arith.constant 16 : index
     %c48 = arith.constant 48 : index
+
+    //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
+    //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[offset3:.*]] = arith.remsi %[[c48]], %[[c16]] : index
+    //CHECK: %[[c1024:.*]] = arith.constant 1024 : index
+    //CHECK: %[[mul0:.*]] = arith.muli %[[offset0]], %[[c1024]] : index
+    //CHECK: %[[add0:.*]] = arith.addi %[[mul0]], %[[c0]] : index
+    //CHECK: %[[c256:.*]] = arith.constant 256 : index
+    //CHECK: %[[mul1:.*]] = arith.muli %[[offset2]], %[[c256]] : index
+    //CHECK: %[[add1:.*]] = arith.addi %[[mul1]], %[[add0]] : index
+    //CHECK: %[[mul2:.*]] = arith.muli %[[offset1]], %[[c16]] : index
+    //CHECK: %[[add2:.*]] = arith.addi %[[mul2]], %[[add1]] : index
+    //CHECK: %[[c1:.*]] = arith.constant 1 : index
+    //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
+    //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
+    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
+    //CHECK: %[[c2:.*]] = arith.constant 2 : i64
+    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
+    //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
+    //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
+
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
+
+    //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
+    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>) 
+
+    xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
+
     gpu.return %1: vector<8xf16>
   }
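
As a worked example for load_store_matrix_5 above: with offsets [16, 48], block shape [16, 16] and blocked strides [1024, 256, 16, 1], the blocked offsets are [16/16, 48/16, 16%16, 48%16] = [1, 3, 0, 0], so the linear element offset is 1*1024 + 3*256 + 0*16 + 0*1 = 1792 and the byte offset is 1792 * 2 = 3584 for f16, which is exactly the muli/addi chain the CHECK lines verify before the xevm.blockload.
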
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index bef45438c944e..fee3136195e1d 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -870,6 +870,21 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
   return
 }
 
+// -----
+func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+  return
+}
+
+// -----
+func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
+  // expected-error@+1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
+  return
+}
+
+
 // -----
 func.func @store_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>, %arg1: vector<16x16xf32>) {
   // expected-error at +1 {{failed to verify that all of {mem_desc, data} have same element type}}
@@ -891,6 +906,20 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
   return
 }
 
+// -----
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+
+// -----
+func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
+  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
+  return
+}
+
 // -----
 func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
   // expected-error at +1 {{result shape must not exceed source shape}}

>From 034476186425dde826929584ae36f95fa7263fd8 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 10 Oct 2025 06:19:21 +0000
Subject: [PATCH 06/12] fix minor issues

---
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 67e8246e5536a..05f26354e5a2a 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -61,10 +61,8 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
     return static_cast<int>(xevm::AddrSpace::GLOBAL);
   case xegpu::MemorySpace::SLM:
     return static_cast<int>(xevm::AddrSpace::SHARED);
-  default:
-    llvm_unreachable("Unknown XeGPU memory space");
-    return static_cast<int>(xevm::AddrSpace::GLOBAL);
   }
+  llvm_unreachable("Unknown XeGPU memory space");
 }
 
 // Get same bitwidth flat vector type of new element type.
@@ -186,8 +184,9 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2)
+    if (rank != 2) {
       return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
+    }
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -612,8 +611,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
             LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
         rewriter.replaceOp(op, loadOp);
       } else {
-        auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
-                                             basePtrLLVM);
+        LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
         rewriter.eraseOp(op);
       }
       return success();
@@ -680,8 +678,7 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
               LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
           rewriter.replaceOp(op, loadOp);
         } else {
-          auto storeOp = LLVM::StoreOp::create(rewriter, loc, adaptor.getData(),
-                                               basePtrLLVM);
+          LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
           rewriter.eraseOp(op);
         }
       }

>From 966525b19652cb75c20722dfa9c22bb74d43a87b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Oct 2025 18:04:56 +0000
Subject: [PATCH 07/12] remove vector direction and length attributes

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 18 ------------
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  4 ---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 18 ++----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 29 +++++++------------
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  4 +--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 13 ++-------
 mlir/test/Dialect/XeGPU/ops.mlir              |  8 ++---
 7 files changed, 22 insertions(+), 72 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 601e966b49890..2efd575a652db 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -724,22 +724,4 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
 
 }
 
-def RowOriented : I32EnumAttrCase<"ROW", 0, "row">;
-def ColOriented : I32EnumAttrCase<"COL", 1, "col">;
-def MatrixAccessDirection : 
-    I32EnumAttr<"MatrixAccessDirection", 
-    "Matrix elements/vectors can have row or column direction", [
-    RowOriented, ColOriented
-]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::xegpu";
-}
-def MatrixAccessDirectionAttr : 
-    EnumAttr<XeGPU_Dialect, 
-    MatrixAccessDirection, 
-    "matrix_access_direction">{
-  let summary = [{Describe the direction of memory access for load_matrix and store_matrix.}];
-  let assemblyFormat = "`<` $value `>`";
-}
-
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 044a8ef22d891..f41f9e887cff7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1302,8 +1302,6 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
   let arguments = (ins XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<I32Attr>:$vec_length,
-    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
     OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
@@ -1355,8 +1353,6 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
     XeGPU_MemDesc:$mem_desc,
     Variadic<Index>: $offsets,
     DenseI64ArrayAttr: $const_offsets,
-    OptionalAttr<I32Attr>:$vec_length,
-    OptionalAttr<MatrixAccessDirectionAttr>:$vec_direction,
     OptionalAttr<UnitAttr>:$subgroup_block_io,
     OptionalAttr<DistributeLayoutAttr>:$layout
   );
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 05f26354e5a2a..2ff2c98d291d2 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -184,9 +184,9 @@ class CreateNdDescToXeVMPattern
     SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
     // Descriptor shape is expected to be 2D.
     int64_t rank = mixedSizes.size();
-    if (rank != 2) {
+    if (rank != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
-    }
+
     auto sourceTy = source.getType();
     auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
     // If source is a memref, we need to extract the aligned pointer as index.
@@ -658,20 +658,6 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
           return rewriter.notifyMatchFailure(
               op, "The lowering is specific to pvc or bmg.");
         }
-        xegpu::MatrixAccessDirectionAttr vecDirection =
-            op.getVecDirectionAttr();
-        if (vecDirection &&
-            vecDirection.getValue() == xegpu::MatrixAccessDirection::COL &&
-            !mdescTy.isColMajor())
-          return rewriter.notifyMatchFailure(
-              op, "mem_desc should be column major when "
-                  "vec_direction is COLUMN for 1D result.");
-        if (vecDirection &&
-            vecDirection.getValue() == xegpu::MatrixAccessDirection::ROW &&
-            mdescTy.isColMajor())
-          return rewriter.notifyMatchFailure(
-              op, "mem_desc should be row major when "
-                  "vec_direction is ROW for 1D result.");
 
         if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
           Value loadOp =
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0bc7b3f06ec53..8d86e78fcbf4f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -173,17 +173,18 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
   return success();
 }
 
-LogicalResult IsValidStoreMatrixParams(
-    VectorType dataTy, MemDescType mdescTy, UnitAttr subgroup_block_io,
-    MatrixAccessDirectionAttr vecDirection, IntegerAttr vecLength,
-    function_ref<InFlightDiagnostic()> emitError) {
-
-  if (!dataTy)
-    if (subgroup_block_io || vecDirection || vecLength)
-      return emitError() << "vec_length, vec_direction and subgroup_block_io "
+LogicalResult
+IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
+                         UnitAttr subgroup_block_io,
+                         function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!dataTy) {
+    if (subgroup_block_io)
+      return emitError() << "subgroup_block_io "
                             "are only allowed when result is a 1D VectorType.";
     else
       return success();
+  }
 
   if (mdescTy.getRank() != 2)
     return emitError() << "mem_desc must be 2D.";
@@ -192,8 +193,8 @@ LogicalResult IsValidStoreMatrixParams(
   ArrayRef<int64_t> mdescShape = mdescTy.getShape();
 
   if (dataShape.size() == 2) {
-    if (subgroup_block_io || vecDirection || vecLength)
-      return emitError() << "vec_length, vec_direction and subgroup_block_io "
+    if (subgroup_block_io)
+      return emitError() << "subgroup_block_io "
                             "are only allowed when result is a 1D VectorType.";
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
@@ -1097,7 +1098,6 @@ void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
   // Call the generated builder with all parameters (including optional ones as
   // nullptr/empty)
   build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
-        /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
         /*subgroup_block_io=*/nullptr, layout);
 }
 
@@ -1105,12 +1105,9 @@ LogicalResult LoadMatrixOp::verify() {
 
   auto resTy = dyn_cast<VectorType>(getRes().getType());
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
-  MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
-  IntegerAttr vecLength = getVecLengthAttr();
   MemDescType mdescTy = getMemDesc().getType();
 
   return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
-                                  vecDirection, vecLength,
                                   [&]() { return emitError(); });
 }
 
@@ -1126,7 +1123,6 @@ void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
   auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
   build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
-        /*vec_length=*/nullptr, /*vec_direction=*/nullptr,
         /*subgroup_block_io=*/nullptr, layout);
 }
 
@@ -1134,11 +1130,8 @@ LogicalResult StoreMatrixOp::verify() {
 
   auto dataTy = dyn_cast<VectorType>(getData().getType());
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
-  MatrixAccessDirectionAttr vecDirection = getVecDirectionAttr();
-  IntegerAttr vecLength = getVecLengthAttr();
   MemDescType mdescTy = getMemDesc().getType();
   return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
-                                  vecDirection, vecLength,
                                   [&]() { return emitError(); });
 }
 
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index 6302758195e51..ebb3c2b2b6a83 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -139,10 +139,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
      
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
-    %1 = xegpu.load_matrix %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
+    %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
 
     //CHECK: llvm.store %[[loaded]], {{.*}} : vector<8xf16>, !llvm.ptr<3>
-    xegpu.store_matrix %1, %0[%c16, %tid_x] {vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+    xegpu.store_matrix %1, %0[%c16, %tid_x] : vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
 
     gpu.return %1: vector<8xf16>
   }
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index fee3136195e1d..6062eba709b88 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -870,16 +870,9 @@ func.func @load_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>) {
   return
 }
 
-// -----
-func.func @load_mem_desc_invalid_attr1(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
-  %data1 = xegpu.load_matrix %arg0[8, 8]<{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
-  return
-}
-
 // -----
 func.func @load_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
   %data2 = xegpu.load_matrix %arg0[8, 8] <{subgroup_block_io}>: !xegpu.mem_desc<16x64xf16> -> vector<16x16xf16>
   return
 }
@@ -908,14 +901,14 @@ func.func @store_mem_desc_invalid_rank(%arg0: !xegpu.mem_desc<64xf16>, %arg1: ve
 
 // -----
 func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
   xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
   return
 }
 
 // -----
 func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data: vector<16x16xf16>) {
-  // expected-error at +1 {{vec_length, vec_direction and subgroup_block_io are only allowed when result is a 1D VectorType.}}
+  // expected-error at +1 {{subgroup_block_io are only allowed when result is a 1D VectorType.}}
   xegpu.store_matrix %data,  %arg0[8, 8] <{subgroup_block_io}>: vector<16x16xf16>, !xegpu.mem_desc<16x64xf16>
   return
 }
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index eb5d653be8b9c..f1f5f86d33bc0 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -855,8 +855,8 @@ gpu.func @simt_load_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16, #
 
 // CHECK: gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
 gpu.func @simt_load_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
-  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> 
-  %data = xegpu.load_matrix %arg0[8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16> 
+  %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> vector<8xf16>
   gpu.return
 }
 
@@ -890,8 +890,8 @@ gpu.func @simt_store_matrix_subgroup_block_io(%arg0: !xegpu.mem_desc<16x64xf16,
 
 // CHECK: gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
 gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>, %arg1: vector<8xf16>) {
-  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] <{vec_direction = #xegpu.matrix_access_direction<col>, vec_length = 8 : i32}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> 
-  xegpu.store_matrix %arg1, %arg0[8, 8] <{vec_length = 8 : i32, vec_direction = #xegpu.matrix_access_direction<col>}>: vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> 
+  xegpu.store_matrix %arg1, %arg0[8, 8] : vector<8xf16>, !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>
   gpu.return
 }
 

>From 272f51213290a1784ac8a44124fbb38b67c9b1c3 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 13 Oct 2025 23:15:41 +0000
Subject: [PATCH 08/12] address comments

---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       | 26 ++++++++++++-------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  8 +++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  4 +--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  3 +++
 .../Transforms/XeGPUWgToSgDistribute.cpp      |  1 +
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  2 +-
 6 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index c261fbb576642..99526159cd2e7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -242,7 +242,6 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       if (layout && layout.hasAttr("stride")) {
         return layout.getStrides();
       }
-
       // derive and return default strides
       SmallVector<int64_t> defaultStrides;
       llvm::append_range(defaultStrides, getShape().drop_front());
@@ -251,6 +250,15 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return builder.getI64ArrayAttr(defaultStrides);
     }
 
+    ArrayAttr getBlockAttr() {
+      auto layout = getMemLayout();
+      if (layout && layout.hasAttr("block")) {
+        return layout.getBlockAttr();
+      }
+      Builder builder(getContext());
+      return builder.getI64ArrayAttr({});
+    }
+
     /// Heuristic to determine if the MemDesc uses column-major layout,
     /// based on the rank and the value of the first stride dimension.
     bool isColMajor() {
@@ -261,16 +269,14 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
     // get the Blocking shape for a MemDescType, Which is represented
     // as an attribute in MemDescType. By default it is the shape
     // of the mdescTy
-    SmallVector<int64_t> getBlockSize() {
+    SmallVector<int64_t> getBlockShape() {
       SmallVector<int64_t> size(getShape());
-      MemLayoutAttr layout = getMemLayout();
-      if (layout && layout.hasAttr("block")) {
-        ArrayAttr attr = layout.getBlockAttr();
+      ArrayAttr blockAttr = getBlockAttr();
+      if (!blockAttr.empty()) {
         size.clear();
-        llvm::for_each(attr, [&](Attribute elem) {
-          if (auto intElem = dyn_cast<IntegerAttr>(elem))
-            size.push_back(intElem.getInt());
-        });
+        for (auto attr : blockAttr.getValue()) {
+          size.push_back(cast<IntegerAttr>(attr).getInt());
+        }
       }
       return size;
     }
@@ -289,7 +295,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
     // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
     // for  mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
     // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
-    SmallVector<int64_t> getStrides();
+    SmallVector<int64_t> getStrideShape();
     
     /// Generates instructions to compute the linearize offset 
     //  if the memory descriptor is blocked, it returns linearize offset based on the blocked layout
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index cccc8fab4adbc..78eee0102ba85 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -776,7 +776,7 @@ SmallVector<OpFoldResult> getBlockedOffsets(OpBuilder &builder, Location loc,
 }
 
 // Get strides as vector of integer for MemDesc.
-SmallVector<int64_t> MemDescType::getStrides() {
+SmallVector<int64_t> MemDescType::getStrideShape() {
 
   SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
 
@@ -786,7 +786,7 @@ SmallVector<int64_t> MemDescType::getStrides() {
     strides.push_back(cast<IntegerAttr>(attr).getInt());
   }
 
-  SmallVector<int64_t> innerBlkShape = getBlockSize();
+  SmallVector<int64_t> innerBlkShape = getBlockShape();
 
   // get perm from FCD to LCD
   // perm[i] = the dim with i-th smallest stride
@@ -837,8 +837,8 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
                                     ArrayRef<OpFoldResult> offsets) {
 
   SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
-  SmallVector<int64_t> blockShape = getBlockSize();
-  SmallVector<int64_t> strides = getStrides();
+  SmallVector<int64_t> blockShape = getBlockShape();
+  SmallVector<int64_t> strides = getStrideShape();
 
   // blockshape equal to matrixshape means no blocking
   if (llvm::equal(blockShape, matrixShape)) {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8d86e78fcbf4f..8c7a686b8ce0d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -201,10 +201,10 @@ IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
       return emitError() << "data shape must not exceed mem_desc shape.";
   } else if (dataShape.size() == 1) {
 
-    SmallVector<int64_t> blockSize = mdescTy.getBlockSize();
+    SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
     // if the subgroup_block_io attribute is set,  mdescTy must have block
     // attribute
-    if (subgroup_block_io && !blockSize.size())
+    if (subgroup_block_io && !blockShape.size())
       return emitError() << "mem_desc must have block attribute when "
                             "subgroup_block_io is set.";
     // if the subgroup_block_io attribute is set, the memdesc should be row
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 6d17b27849a43..aafa1b7deb84b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -942,6 +942,8 @@ struct UnrollLoadMatrixOp : public UnrollPattern<xegpu::LoadMatrixOp> {
                                 PatternRewriter &rewriter) const override {
     Location loc = op.getLoc();
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getType());
+    assert(valueTy && "the value type must be vector type!");
+
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
     if (!targetShape || targetShape->size() != (size_t)valueTy.getRank())
       return failure();
@@ -985,6 +987,7 @@ struct UnrollStoreMatrixOp : public UnrollPattern<xegpu::StoreMatrixOp> {
 
     Location loc = op.getLoc();
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getData().getType());
+    assert(valueTy && "the value type must be vector type!");
     ArrayRef<int64_t> shape = valueTy.getShape();
     auto layout = dyn_cast<xegpu::LayoutAttr>(op.getLayoutAttr());
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index baee57c512ddf..31a967dcd04c7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -992,6 +992,7 @@ struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> {
 
     ArrayRef<int64_t> wgShape = op.getDataShape();
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getRes().getType());
+    assert(valueTy && "the value type must be vector type!");
     Type elemTy = valueTy.getElementType();
 
     xegpu::DistributeLayoutAttr layout = op.getLayoutAttr();
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index ebb3c2b2b6a83..df1433e7b98ae 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -198,4 +198,4 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: vector<8xf16>
   }
 
-}
\ No newline at end of file
+}

>From b1857a275d7e30a55ac9b17b335f61f556b2e695 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 01:05:09 +0000
Subject: [PATCH 09/12] address more comments

---
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 134 ++++++++----------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  14 +-
 .../XeGPUToXeVM/loadstore_matrix.mlir         |  18 +--
 3 files changed, 77 insertions(+), 89 deletions(-)

diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 2ff2c98d291d2..e5e797a1fa1c3 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -365,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
-                       Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value baseAddr, Value offset,
+                                 int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
-      rewriter, loc, rewriter.getI64Type(), elemByteSize);
+      rewriter, loc, baseAddr.getType(), elemByteSize);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
@@ -443,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       // If offset is provided, we add them to the base pointer.
       // Offset is in number of elements, we need to multiply by
       // element byte size.
-      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+      basePtrI64 =
+          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -516,7 +518,7 @@ class CreateMemDescOpPattern final
   LogicalResult
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TypedValue<MemRefType> src = op.getSource();
+
     auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
 
     // Create the result MemRefType with the same shape, element type, and
@@ -525,7 +527,7 @@ class CreateMemDescOpPattern final
 
     Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
     auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-                                         Value(src), zero, ValueRange());
+                                         op.getSource(), zero, ValueRange());
     rewriter.replaceOp(op, viewOp);
     return success();
   }
@@ -587,88 +589,74 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
         rewriter, loc, basePtrStruct);
 
-    // Convert base pointer (ptr) to i64
-    Value basePtrI64 = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+    // Convert base pointer (ptr) to i32
+    Value basePtrI32 = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
 
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI64Type(), linearOffset);
-    basePtrI64 =
-        addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+        rewriter, loc, rewriter.getI32Type(), linearOffset);
+    basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+                                     elemByteSize);
 
-    // convert base pointer (i64) to LLVM pointer type
+    // convert base pointer (i32) to LLVM pointer type
     basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
 
-    // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
-    // operation. LLVM load/store does not support vector of size 1, so we need
-    // to handle this case separately.
-    if (valOrResVecTy.getNumElements() == 1) {
-      Type scalarTy = valOrResVecTy.getElementType();
-      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-        Value loadOp =
-            LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
-        rewriter.replaceOp(op, loadOp);
-      } else {
-        LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
-        rewriter.eraseOp(op);
-      }
-      return success();
-    } else {
+    if (op.getSubgroupBlockIoAttr()) {
       // if the attribute 'subgroup_block_io' is set to true, it lowers to
       // xevm.blockload
-      auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
-      bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
-
-      // BlockLoadOp only supports integer types, so we need to bitcast
-      // Get integer type with matching bit width
-      Type elemTy = valOrResVecTy.getElementType();
-      int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
-      Type intElemTy = rewriter.getIntegerType(bitWidth);
+
+      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
       VectorType intVecTy =
           VectorType::get(valOrResVecTy.getShape(), intElemTy);
 
-      if (subgroup_block_io) {
-        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-          Value loadOp =
-              xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
-          if (intVecTy != valOrResVecTy) {
-            loadOp =
-                vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
-          }
-          rewriter.replaceOp(op, loadOp);
-        } else {
-          Value dataToStore = adaptor.getData();
-          if (valOrResVecTy != intVecTy) {
-            dataToStore =
-                vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
-          }
-          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
-                                     nullptr);
-          rewriter.eraseOp(op);
+      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+        Value loadOp =
+            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+        if (intVecTy != valOrResVecTy) {
+          loadOp =
+              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
         }
+        rewriter.replaceOp(op, loadOp);
       } else {
-        // if the result is 1D vector, if the vector direction is Column, then
-        // the
-        //  memory descriptor should be treated as column major
-        auto chipOpt = xegpu::getChipStr(op);
-        if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
-          // the lowering only works for pvc and bmg
-          return rewriter.notifyMatchFailure(
-              op, "The lowering is specific to pvc or bmg.");
+        Value dataToStore = adaptor.getData();
+        if (valOrResVecTy != intVecTy) {
+          dataToStore =
+              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
         }
+        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+                                   nullptr);
+        rewriter.eraseOp(op);
+      }
+      return success();
+    }
 
-        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-          Value loadOp =
-              LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
-          rewriter.replaceOp(op, loadOp);
-        } else {
-          LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
-          rewriter.eraseOp(op);
-        }
+    if (valOrResVecTy.getNumElements() >= 1) {
+      auto chipOpt = xegpu::getChipStr(op);
+      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+        // the lowering for chunk load only works for pvc and bmg
+        return rewriter.notifyMatchFailure(
+            op, "The lowering is specific to pvc or bmg.");
       }
     }
+
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+      // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+      // operation. LLVM load/store does not support vector of size 1, so we
+      // need to handle this case separately.
+      auto scalarTy = valOrResVecTy.getElementType();
+      LLVM::LoadOp loadOp;
+      if (valOrResVecTy.getNumElements() == 1)
+        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+      else
+        loadOp =
+            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+      rewriter.replaceOp(op, loadOp);
+    } else {
+      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+      rewriter.eraseOp(op);
+    }
     return success();
   }
 };
@@ -715,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
                 op, "Expected element type bit width to be multiple of 8.");
           elemByteSize = elemBitWidth / 8;
         }
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+                                         elemByteSize);
       }
     }
     // Default memory space is global.
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8c7a686b8ce0d..7108afffe99d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -174,9 +174,9 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
 }
 
 LogicalResult
-IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
-                         UnitAttr subgroup_block_io,
-                         function_ref<InFlightDiagnostic()> emitError) {
+IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
+                      UnitAttr subgroup_block_io,
+                      function_ref<InFlightDiagnostic()> emitError) {
 
   if (!dataTy) {
     if (subgroup_block_io)
@@ -1107,8 +1107,8 @@ LogicalResult LoadMatrixOp::verify() {
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
   MemDescType mdescTy = getMemDesc().getType();
 
-  return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
-                                  [&]() { return emitError(); });
+  return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
+                               [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
@@ -1131,8 +1131,8 @@ LogicalResult StoreMatrixOp::verify() {
   auto dataTy = dyn_cast<VectorType>(getData().getType());
   UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
   MemDescType mdescTy = getMemDesc().getType();
-  return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
-                                  [&]() { return emitError(); });
+  return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
+                               [&]() { return emitError(); });
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
index df1433e7b98ae..d4cb493271d0d 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir
@@ -11,8 +11,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[TID:.*]] = gpu.thread_id x
     //CHECK: %[[C1:.*]] = arith.constant 1 : index
     //CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
-    //CHECK: %[[C4:.*]] = arith.constant 4 : i64
-    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
+    //CHECK: %[[C4:.*]] = arith.constant 4 : i32
+    //CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
@@ -80,7 +80,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %c19 = arith.constant 19: index
     
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -164,7 +164,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %c48 = arith.constant 48 : index
 
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
-    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
+    //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     //CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
     //CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
@@ -180,11 +180,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[c1:.*]] = arith.constant 1 : index
     //CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
     //CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
-    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
-    //CHECK: %[[c2:.*]] = arith.constant 2 : i64
-    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
-    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
-    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
+    //CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
+    //CHECK: %[[c2:.*]] = arith.constant 2 : i32
+    //CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
+    //CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
+    //CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
     //CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
     //CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
 

>From 7a63d93d076d8b90ff27e3e4f88b008780078f75 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 21:24:07 +0000
Subject: [PATCH 10/12] address more feedback

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  2 +-
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 -------------------
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  6 +--
 .../Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp    | 15 +-------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 34 -----------------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 28 --------------
 mlir/test/Dialect/XeGPU/ops.mlir              | 21 -----------
 8 files changed, 6 insertions(+), 139 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index 2efd575a652db..19a52317956d2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -712,7 +712,7 @@ def XeGPU_MemLayoutAttr : XeGPUAttr<"MemLayout", "mem_layout"> {
       return getAttrs().contains(name);
     }
 
-    ArrayAttr getStrides() {
+    ArrayAttr getStrideAttr() {
       return getAttrs().getAs<ArrayAttr>("stride");
     }
 
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index f41f9e887cff7..73b70da9642e4 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1392,41 +1392,4 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
   let hasVerifier = 1;
 }
 
-def XeGPU_MemDescSubviewOp: XeGPU_Op<"mem_desc_subview",
-          [Pure, ViewLikeOpInterface, AllElementTypesMatch<["src", "res"]>]> {
-  let description = [{
-    Creates a subview of a memory descriptor. The resulting memory descriptor can have
-    a lower rank than the source; in this case, the result dimensions correspond to the
-    higher-order dimensions of the source memory descriptor.
-
-    Arguments:
-     - `src` : a memory descriptor.
-     - `offsets` : the coordinates within the matrix the subview will be created from.
-
-    Results:
-    - `res` : a memory descriptor with smaller size.
-
-  }];
-  let arguments = (ins XeGPU_MemDesc:$src,
-                       Variadic<Index>:$offsets,
-                       DenseI64ArrayAttr:$const_offsets);
-  let results = (outs XeGPU_MemDesc:$res);
-  let assemblyFormat = [{$src `` custom<DynamicIndexList>($offsets, $const_offsets) prop-dict
-                         attr-dict `` `:` qualified(type($src)) `->` qualified(type($res))}];
-  let builders = [
-    OpBuilder<(ins "Type": $res, "Value":$src, "llvm::ArrayRef<OpFoldResult>": $offsets)>
-  ];
-
-  let extraClassDeclaration = [{
-    mlir::Value getViewSource() { return getSrc(); }
-
-    SmallVector<OpFoldResult> getMixedOffsets() {
-      return getMixedValues(getConstOffsets(), getOffsets(), getContext());
-    }
-  }];
-
-  let hasVerifier = 1;
-}
-
-
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 99526159cd2e7..024ca2023c811 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -237,10 +237,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
       return MemDescType::get(getContext(), shape.value_or(getShape()), elementType, getMemLayout());
     }
 
-    ArrayAttr getStridesAttr() {
+    ArrayAttr getStrideAttr() {
       auto layout = getMemLayout();
       if (layout && layout.hasAttr("stride")) {
-        return layout.getStrides();
+        return layout.getStrideAttr();
       }
       // derive and return default strides
       SmallVector<int64_t> defaultStrides;
@@ -262,7 +262,7 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
     /// Heuristic to determine if the MemDesc uses column-major layout,
     /// based on the rank and the value of the first stride dimension.
     bool isColMajor() {
-      auto dim0 = dyn_cast<IntegerAttr>(getStridesAttr()[0]);
+      auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
       return getRank() == 2 && dim0 && dim0.getInt() == 1;
     }
 
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index fdd29dd96cd55..9cf963e101816 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -534,18 +534,6 @@ class CreateMemDescOpPattern final
   }
 };
 
-class MemDescSubviewOpPattern final
-    : public OpConversionPattern<xegpu::MemDescSubviewOp> {
-public:
-  using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern;
-  LogicalResult
-  matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    return rewriter.notifyMatchFailure(
-        op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture.");
-  }
-};
-
 template <typename OpType,
           typename = std::enable_if_t<llvm::is_one_of<
               OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -1085,8 +1073,7 @@ void mlir::populateXeGPUToXeVMConversionPatterns(
       typeConverter, patterns.getContext());
   patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
                LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
-               CreateMemDescOpPattern, MemDescSubviewOpPattern>(
-      typeConverter, patterns.getContext());
+               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
   patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                       patterns.getContext());
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index dc880308e6b3a..1cfae28f31188 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -781,7 +781,7 @@ SmallVector<int64_t> MemDescType::getStrideShape() {
 
   SmallVector<int64_t> matrixShape(getShape().begin(), getShape().end());
 
-  ArrayAttr strideAttr = getStridesAttr();
+  ArrayAttr strideAttr = getStrideAttr();
   SmallVector<int64_t> strides;
   for (Attribute attr : strideAttr.getValue()) {
     strides.push_back(cast<IntegerAttr>(attr).getInt());
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7108afffe99d5..f2d1b85fa1737 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -1135,40 +1135,6 @@ LogicalResult StoreMatrixOp::verify() {
                                [&]() { return emitError(); });
 }
 
-//===----------------------------------------------------------------------===//
-// XeGPU_MemDescSubviewOp
-//===----------------------------------------------------------------------===//
-
-void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
-                             Type resTy, Value src,
-                             llvm::ArrayRef<OpFoldResult> offsets) {
-  llvm::SmallVector<Value> dynamicOffsets;
-  llvm::SmallVector<int64_t> staticOffsets;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
-  build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
-}
-
-LogicalResult MemDescSubviewOp::verify() {
-  MemDescType srcTy = getSrc().getType();
-  MemDescType resTy = getRes().getType();
-  ArrayRef<int64_t> srcShape = srcTy.getShape();
-  ArrayRef<int64_t> resShape = resTy.getShape();
-
-  if (srcTy.getRank() < resTy.getRank())
-    return emitOpError("result rank must not exceed source rank.");
-
-  if (llvm::any_of(
-          llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
-          [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
-    return emitOpError("result shape must not exceed source shape.");
-
-  if (srcTy.getStridesAttr() != resTy.getStridesAttr())
-    return emitOpError("result must inherit the source strides.");
-
-  return success();
-}
-
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 6062eba709b88..ebbe3ce0ec0d0 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -913,31 +913,3 @@ func.func @store_mem_desc_invalid_attr2(%arg0: !xegpu.mem_desc<16x64xf16>, %data
   return
 }
 
-// -----
-func.func @mem_desc_subview_size_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{result shape must not exceed source shape}}
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<32x16xf16>
-  return
-}
-
-// -----
-func.func @mem_desc_subview_layout_mismatch(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>>) {
-  // expected-error at +1 {{result must inherit the source strides}}
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride =[1, 16]>> -> !xegpu.mem_desc<8x16xf16>
-  return
-}
-
-// -----
-func.func @mem_desc_subview_element_type_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{failed to verify that all of {src, res} have same element type}}
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf32, #xegpu.mem_layout<stride =[64, 1]>>
-  return
-}
-
-// -----
-func.func @mem_desc_subview_rank_mismatch(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  // expected-error at +1 {{result rank must not exceed source rank}}
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<4x8x16xf16>
-  return
-}
-
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index f1f5f86d33bc0..0a10f6814ae96 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -895,25 +895,4 @@ gpu.func @simt_store_matrix_vector(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_
   gpu.return
 }
 
-// CHECK: gpu.func @mem_desc_subview([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>
-  gpu.return
-}
-
-// CHECK: gpu.func @mem_desc_subview_lower_rank([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16>)
-gpu.func @mem_desc_subview_lower_rank(%arg0: !xegpu.mem_desc<16x64xf16>) {
-  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> !xegpu.mem_desc<16xf16, #xegpu.mem_layout<stride = [64, 1]>>
-  gpu.return
-}
-
-// CHECK: gpu.func @mem_desc_subview_with_stride([[ARG0:%.+]]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>)
-gpu.func @mem_desc_subview_with_stride(%arg0: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>>) {
-  //CHECK: xegpu.mem_desc_subview [[ARG0]][8, 8] : !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
-  %data = xegpu.mem_desc_subview %arg0[8, 8]: !xegpu.mem_desc<16x64xf16, #xegpu.mem_layout<stride = [1, 16]>> -> !xegpu.mem_desc<8x16xf16, #xegpu.mem_layout<stride = [1, 16]>>
-  gpu.return
-}
-
 }

>From de87d094a9a309b66138d2a357d1ac73b8270c2b Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 14 Oct 2025 22:05:14 +0000
Subject: [PATCH 11/12] address minor comments

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index f2d1b85fa1737..464a9e2d2a806 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -199,8 +199,7 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
     if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
                      [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
       return emitError() << "data shape must not exceed mem_desc shape.";
-  } else if (dataShape.size() == 1) {
-
+  } else {
     SmallVector<int64_t> blockShape = mdescTy.getBlockShape();
     // if the subgroup_block_io attribute is set,  mdescTy must have block
     // attribute
@@ -212,8 +211,6 @@ IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
     if (subgroup_block_io && mdescTy.isColMajor())
       return emitError() << "mem_desc should be row major when "
                             "subgroup_block_io is set.";
-  } else if (dataShape.size() == 0) {
-    return emitError() << "result shape must not be empty.";
   }
 
   return success();

>From faa0bfb3eb6004dcaf33269b9e161051a96baa79 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 15 Oct 2025 23:35:57 +0000
Subject: [PATCH 12/12] address comments

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td   |  6 ++++++
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 14 ++++++++------
 mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp  |  2 +-
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 73b70da9642e4..426377fcf598f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -1319,6 +1319,9 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
     Arguments:
      - `mem_desc`: the memory descriptor identifying the SLM region.
      - `offsets`: the coordinates within the matrix to read from.
+     - `subgroup_block_io`: [optional] An attribute indicating that the operation can be 
+                 lowered to a subgroup block load. When this attribute is present, 
+                 the offsets are subgroup-uniform across all lanes.
      - `layout`: [optional] An attribute for guiding distributions among
                  subgroups and/or work-items. It currently can accept either
                  LayoutAttr or SliceAttr.
@@ -1367,6 +1370,9 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
      - `mem_desc`: the memory descriptor specifying the SLM region.
      - `offsets`: the coordinates within the matrix where the data will be written.
      - `data`: the values to be stored in the matrix.
+     - `subgroup_block_io`: [optional] An attribute indicating that the operation can be 
+                 lowered to a subgroup block store. When this attribute is present, 
+                 the offsets are subgroup-uniform across all lanes.     
      - `layout`: [optional] An attribute for guiding distributions among
                  subgroups and/or work-items. It currently can accept either
                  LayoutAttr or SliceAttr.
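
For illustration, a minimal SIMT-level use of the attribute, modeled on the conversion test in this patch (the function name and the concrete shapes are only an example, not part of the change):

  gpu.func @simt_load_store_matrix_block_io(%arg0: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
    %c16 = arith.constant 16 : index
    %c48 = arith.constant 48 : index
    // subgroup-uniform offsets: lowers to xevm.blockload
    %data = xegpu.load_matrix %arg0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
    // same offsets on the store side: lowers to xevm.blockstore
    xegpu.store_matrix %data, %arg0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
    gpu.return
  }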
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 024ca2023c811..b1196fbe9c66a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -263,10 +263,10 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
     /// based on the rank and the value of the first stride dimension.
     bool isColMajor() {
       auto dim0 = dyn_cast<IntegerAttr>(getStrideAttr()[0]);
-      return getRank() == 2 && dim0 && dim0.getInt() == 1;
+      return getRank() == 2 && dim0.getInt() == 1;
     }
 
-    // get the Blocking shape for a MemDescType, Which is represented
+    // Get the blocking shape for a MemDescType, which is represented
     // as an attribute in MemDescType. By default it is the shape
     // of the mdescTy
     SmallVector<int64_t> getBlockShape() {
@@ -284,16 +284,18 @@ def XeGPU_MemDesc: XeGPUTypeDef<"MemDesc", "mem_desc", [ShapedTypeInterface], "m
     // Get strides as vector of integer. 
     // If it contains block attribute, the strides are blocked strides.
     //
-    // The blocking is applied against the original matrix shape
-    // so that the linear offset is not impacted by the subview.
+    // The blocking is applied to the base matrix shape derived from the
+    // memory descriptor's stride information. If the matrix described by
+    // the memory descriptor is not contiguous, it is assumed that the base
+    // matrix is contiguous and follows the same memory layout.
     //
     // It first computes the original matrix shape using the stride info,
     // then computes the number of blocks in each dimension of original shape,
     // then compute the outer block shape and stride,
     // then combines the inner and outer block shape and stride
-    // e.g. for mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>
+    // e.g. for `mem_desc<32x256xf16, @block=[16, 8], @strides=[1, 32]>`
     // its memory layout tuple is ([2,32,16,8],[128,256,1,16])
-    // for  mem_desc<256x32xf16, @block=[8, 16]> with default @stride[32, 1]
+    // for  `mem_desc<256x32xf16, @block=[8, 16]>` with default @stride[32, 1]
     // its memory layout tuple is ([32,2,8,16],[256,128,16,1])
     SmallVector<int64_t> getStrideShape();
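
Spelling out the second example quoted above: with shape 256x32, block [8, 16] and the default stride [32, 1], the matrix splits into 32x2 blocks of 8x16 elements. Each block is stored contiguously (8*16 = 128 elements), giving inner strides [16, 1]; a full row of 2 blocks spans 2*128 = 256 elements, giving outer strides [256, 128]. Combining the outer and inner parts yields the layout tuple ([32,2,8,16],[256,128,16,1]).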
     
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index 9cf963e101816..9ee384e46ef33 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -520,7 +520,7 @@ class CreateMemDescOpPattern final
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
 
-    auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
+    auto resTy = op.getMemDesc();
 
     // Create the result MemRefType with the same shape, element type, and
     // memory space


