[Mlir-commits] [mlir] [MLIR][XeGPU] Allow load/store/prefetch to use [memref+offset] instead of tdesc (PR #150576)

Jianhui Li llvmlistbot at llvm.org
Wed Jul 30 10:54:07 PDT 2025


https://github.com/Jianhui-Li updated https://github.com/llvm/llvm-project/pull/150576
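
In short, this series lets the scattered load/store/prefetch ops (and, via
optional offsets, the nd variants) address a plain memref or raw ui64 base
directly, rather than requiring a pre-built tensor_desc. A minimal
before/after sketch, assembled from the tests added later in this series
(the %base name is illustrative):

  // Before: the offset is baked into the descriptor.
  %td = xegpu.create_nd_tdesc %src[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
  %v0 = xegpu.load_nd %td : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>

  // After: scattered ops take the base (memref or ui64) plus per-lane
  // offsets and a mask; chunk_size gives the elements loaded per lane.
  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1> : vector<4xi1>
  %v1 = xegpu.load %base[%offsets], %mask <{chunk_size = 2, l1_hint = #xegpu.cache_hint<cached>}>
      : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>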

>From 1373ffa1d836cb8401f5d24fca5c9283c2484d0e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 01/16] add optional offsets to nd load/store/prefetch

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 70 +++++++++++++++++--
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 32 +++++++--
 mlir/test/Dialect/XeGPU/ops.mlir              | 50 +++++++++++++
 3 files changed, 140 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 81e25f7537cb0..e9f8437d7c102 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,9 +29,22 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
     void printProperties(::mlir::MLIRContext *ctx,
             ::mlir::OpAsmPrinter &p, const Properties &prop,
             ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
-      Attribute propAttr = getPropertiesAsAttr(ctx, prop);
-      if (propAttr)
-        p << "<" << propAttr << ">";
+      
+      DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));
+
+      // filter out the elidedProps from propAttr, and get the resultAttr
+      mlir::SmallVector<mlir::NamedAttribute> filteredAttrs;
+      if (propAttr) {
+        for (auto namedAttr : propAttr.getValue()) {
+          if (llvm::is_contained(elidedProps, namedAttr.getName().strref()))
+            continue;
+          filteredAttrs.push_back(namedAttr);
+        }
+      }
+
+      if (!filteredAttrs.empty()) {
+        p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">"; 
+      }
     }
 
     static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
@@ -288,6 +301,8 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -298,7 +313,18 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $TensorDesc `` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    prop-dict attr-dict `:` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $TensorDesc, 
+                   "xegpu::CachePolicyAttr": $l1_hint, 
+                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
@@ -343,6 +369,8 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
                        OptionalAttr<UnitAttr>: $packed,
                        OptionalAttr<DenseI64ArrayAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -361,7 +389,20 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)";
+  let assemblyFormat = [{
+    $TensorDesc `` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
+                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+                    "xegpu::CachePolicyAttr": $l1_hint, 
+                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
+
   let hasVerifier = 1;
 }
 
@@ -400,6 +441,8 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
 
   let arguments = (ins XeGPU_ValueType: $value,
                        XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -414,8 +457,21 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
-                        `:` type($value) `,` qualified(type($TensorDesc))}];
+   let assemblyFormat = [{
+    $value `,` 
+    $TensorDesc `` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    prop-dict attr-dict `:`  type($value) `,` qualified(type($TensorDesc))
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $TensorDesc, 
+                   "xegpu::CachePolicyAttr": $l1_hint, 
+                   "xegpu::CachePolicyAttr": $l2_hint, 
+                   "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
+
+
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 78cbf884a1911..ca3c92cf4b52c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -331,16 +331,24 @@ ParseResult parseOptionalDynamicIndexList(
 
 void printOptionalDynamicIndexList(
     OpAsmPrinter &printer, Operation *op, OperandRange values,
-    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
-    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+    DenseI64ArrayAttr integers) {
 
-  return printDynamicIndexList(printer, op, values, integers,
-                               /*scalableFlags=*/{}, valueTypes, delimiter);
-}
+    if (!integers) 
+      return;
 
+    return printDynamicIndexList(printer, op, values, integers,
+                      /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
+   }
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
+
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
   if (tdescTy.isScattered())
@@ -361,6 +369,13 @@ LogicalResult PrefetchNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadNdOp
 //===----------------------------------------------------------------------===//
+
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult LoadNdOp::verify() {
   auto tdescTy = getTensorDescType();
   auto valueTy = getType();
@@ -448,6 +463,13 @@ LogicalResult LoadNdOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
+
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
+
+  return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+
+}
+
 LogicalResult StoreNdOp::verify() {
   auto dstTy = getTensorDescType(); // Tile
   auto valTy = getValueType();      // Vector
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 695437354cd7c..a1028a8e8a2f3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -121,6 +121,15 @@ gpu.func @prefetch_nd_2(%src: memref<8x24x32x48x64xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -260,6 +269,15 @@ gpu.func @subgroup_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+// CHECK: func @subgroup_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @subgroup_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>
+  gpu.return
+}
+
 // CHECK: func @simt_load_nd_8(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
@@ -269,6 +287,16 @@ gpu.func @simt_load_nd_8(%src: memref<24x32xf32>) {
   gpu.return
 }
 
+
+// CHECK: func @simt_load_nd_offset_1(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @simt_load_nd_offset_1(%src: memref<24x32xf32>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
+  // CHECK: %[[R1:.*]] = xegpu.load_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  %2 = xegpu.load_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8xf32>
+  gpu.return
+}
+
 // CHECK: func @subgroup_store_nd(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<24x32xf16>
@@ -293,6 +321,17 @@ gpu.func @simt_store_nd(%src: memref<24x32xf16>) {
 
 // CHECK: func @subgroup_store_nd_2(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -313,6 +352,17 @@ gpu.func @simt_store_nd_2(%src: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: func @simt_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @simt_store_nd_offset_1(%src: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<2xf16>
+  %1 = arith.constant dense<1.0>: vector<2xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<2xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @update_nd_tdesc(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @update_nd_tdesc(%src: memref<24x32xf32>) {
   // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
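
At the IR level, patch 01 makes the nd ops accept optional offsets printed
directly after the descriptor operand, so a tensor_desc created at a base
position can be reused at several offsets; store_nd gains the same form. A
sketch lifted from the tests in this patch:

  %td = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<16x8xf32>
  xegpu.prefetch_nd %td[0, 0] <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<16x8xf32>
  %v = xegpu.load_nd %td[0, 0] <{transpose = array<i64: 1, 0>}> : !xegpu.tensor_desc<16x8xf32> -> vector<8x16xf32>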

>From 30ff640a8d2c59d31effbb9828f1775032564c57 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:22:18 +0000
Subject: [PATCH 02/16] git-clang-format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 47 ++++++++++++++++----------
 1 file changed, 30 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ca3c92cf4b52c..7cb105bf4292d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -329,24 +329,28 @@ ParseResult parseOptionalDynamicIndexList(
   return success();
 }
 
-void printOptionalDynamicIndexList(
-    OpAsmPrinter &printer, Operation *op, OperandRange values,
-    DenseI64ArrayAttr integers) {
+void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
+                                   OperandRange values,
+                                   DenseI64ArrayAttr integers) {
 
-    if (!integers) 
-      return;
+  if (!integers)
+    return;
 
-    return printDynamicIndexList(printer, op, values, integers,
-                      /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
-   }
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, {},
+                               AsmParser::Delimiter::Square);
+}
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
 
-void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+                         Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
 
+  return build(builder, state, tensorDesc, ValueRange(), DenseI64ArrayAttr(),
+               l1_hint, l2_hint, l3_hint);
 }
 
 LogicalResult PrefetchNdOp::verify() {
@@ -370,10 +374,16 @@ LogicalResult PrefetchNdOp::verify() {
 // XeGPU_LoadNdOp
 //===----------------------------------------------------------------------===//
 
-void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, Value tensorDesc, UnitAttr packed, DenseI64ArrayAttr transpose, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, retType, tensorDesc, ValueRange(), DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint, l3_hint);
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+                     Value tensorDesc, UnitAttr packed,
+                     DenseI64ArrayAttr transpose,
+                     xegpu::CachePolicyAttr l1_hint,
+                     xegpu::CachePolicyAttr l2_hint,
+                     xegpu::CachePolicyAttr l3_hint) {
 
+  return build(builder, state, retType, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), packed, transpose, l1_hint, l2_hint,
+               l3_hint);
 }
 
 LogicalResult LoadNdOp::verify() {
@@ -464,10 +474,13 @@ LogicalResult LoadNdOp::verify() {
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
 
-void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, Value tensorDesc, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) {
-
-  return build(builder, state, value, tensorDesc, ValueRange(), DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+                      Value tensorDesc, xegpu::CachePolicyAttr l1_hint,
+                      xegpu::CachePolicyAttr l2_hint,
+                      xegpu::CachePolicyAttr l3_hint) {
 
+  return build(builder, state, value, tensorDesc, ValueRange(),
+               DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
 }
 
 LogicalResult StoreNdOp::verify() {
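
Formatting aside, the behavior of printOptionalDynamicIndexList is worth
noting: when const_offsets is absent it prints nothing, so existing
offset-less assembly round-trips unchanged; otherwise it defers to MLIR's
printDynamicIndexList, which renders static offsets inline and dynamic ones
as SSA operands in a single bracketed list. Illustratively, with a
hypothetical index value %x:

  xegpu.prefetch_nd %td[%x, 8] : !xegpu.tensor_desc<8x16xf16>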

>From efd1661b4ea93a08776213504219b96871449507 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 03/16] add optional offsets to load_gather

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 15 ++++++++++-----
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td  |  1 +
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp |  2 +-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e9f8437d7c102..31c2fb357371a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 }
 
 def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-  AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+  AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
   ]> {
   let summary = "load a set of scattered data points from memory.";
 
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let results = (outs XeGPU_ValueType: $value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+       return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
 
     mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
-      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
+      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 277158ac85409..ac41907655122 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -203,6 +203,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index dc76441b27c02..44f2364d0caec 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -486,7 +486,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
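
With $source relaxed to tensor_desc-or-memref, getTensorDescType() returns a
null type for a memref source; later patches in the series use exactly that
to route the verifier. The end state this enables (offsets and chunk_size
arrive in patches 06-08) looks like:

  %v = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
      : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>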

>From 3578c1b96fd22a1e013bc70f987ac4bfb6849c10 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 17 Jul 2025 23:21:14 +0000
Subject: [PATCH 04/16] add optional offsets to nd load/store/prefetch

---
 mlir/test/Dialect/XeGPU/ops.mlir | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3ebb1b969ac74..3523e3083c168 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,6 +130,15 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
+// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
+gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
+  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
+  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -330,6 +339,17 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
+// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
+  %1 = arith.constant dense<1.0>: vector<32xf16>
+  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
+  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
+  gpu.return
+}
+
 // CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>

>From 59f7ea9bf601205496f5867ddf1445e1f51641fd Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 18 Jul 2025 01:48:18 +0000
Subject: [PATCH 05/16] add optional offsets to load_gather

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 15 ++++++++++-----
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td  |  1 +
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp |  2 +-
 3 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 91d6b2a5ead9b..a5a7dab1bf55a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -655,7 +655,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 }
 
 def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-  AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
+  AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
   ]> {
   let summary = "load a set of scattered data points from memory.";
 
@@ -698,7 +698,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,8 +706,13 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let results = (outs XeGPU_ValueType: $value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+       return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
 
     mlir::Type getElementType() {
@@ -725,8 +730,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
-      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
+      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 20916ae9ef830..334f749ace745 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index a6208b455aa35..c8f332184bd1b 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;

>From abc84c759f0993d8aa699ba1b87fdad7c5760a69 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Mon, 21 Jul 2025 04:18:01 +0000
Subject: [PATCH 06/16] add offsets to load

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 18 ++++++++++++++++--
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td        |  2 +-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index a5a7dab1bf55a..51356e963e778 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -699,6 +699,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+                       Variadic<Index>: $offsets,
+                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
                        XeGPU_MaskType: $mask,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -730,8 +732,20 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
 
   }];
 
-  let assemblyFormat = [{$source `,` $mask prop-dict attr-dict
-      `:` qualified(type($source)) `,` type($mask) `->` type($value)}];
+  let assemblyFormat = [{
+    $source `,` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $mask prop-dict 
+    attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+  }];
+
+//  let builders = [
+//    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
+//                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
+//                    "xegpu::CachePolicyAttr": $l1_hint, 
+//                    "xegpu::CachePolicyAttr": $l2_hint, 
+//                    "xegpu::CachePolicyAttr": $l3_hint)>
+//  ];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 334f749ace745..8e575e31255a7 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>]>;
+def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
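
Adding UI64 to the source constraint admits a raw 64-bit base address as
well. A sketch of the form the final verifier accepts (chunk_size lands in
patch 07; %base is an assumed ui64 value):

  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1> : vector<4xi1>
  %v = xegpu.load %base[%offsets], %mask <{chunk_size = 2}>
      : ui64, vector<4xindex>, vector<4xi1> -> vector<4x2xf32>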

>From 80b4462f48c164309f70b1a6dbeb5805d869c998 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Thu, 24 Jul 2025 02:58:15 +0000
Subject: [PATCH 07/16] add chunk_size and use XeGPU_offsetType

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 84 +++++++++++++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 27 +++++-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  2 +-
 mlir/test/Dialect/XeGPU/ops.mlir              | 33 +++-----
 4 files changed, 105 insertions(+), 41 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 51356e963e778..bf036d86d14bb 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -16,6 +16,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 
 // Base class for dialect operations. This operation inherits from the base
 // `Op` class in OpBase.td, and provides:
@@ -638,18 +639,39 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+                       Optional<XeGPU_OffsetType>: $offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    Type getSourceType() {
+      return getSource().getType();
+    }
+
+    Value getTensorDesc() {
+       return getSource();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
   }];
 
-  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+  let assemblyFormat = [{
+    $source 
+    (`,` $offsets^)? 
+    prop-dict
+    attr-dict `:` type($source) (`,` type($offsets)^)?
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $source,
+                    "xegpu::CachePolicyAttr": $l1_hint, 
+                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l3_hint)>
+  ];
 
   let hasVerifier = 1;
 }
@@ -702,6 +724,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
                        Variadic<Index>: $offsets,
                        OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
                        XeGPU_MaskType: $mask,
+                       OptionalAttr<I64Attr>: $chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -713,6 +736,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
       return getSource().getType();
     }
 
+    Value getTensorDesc() {
+       return getSource();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
        return dyn_cast<xegpu::TensorDescType>(getSourceType());
     }
@@ -733,25 +760,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   }];
 
   let assemblyFormat = [{
-    $source `,` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) 
+    $source `` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
     $mask prop-dict 
     attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
   }];
 
-//  let builders = [
-//    OpBuilder<(ins "Type": $value, "Value": $TensorDesc, 
-//                    "UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
-//                    "xegpu::CachePolicyAttr": $l1_hint, 
-//                    "xegpu::CachePolicyAttr": $l2_hint, 
-//                    "xegpu::CachePolicyAttr": $l3_hint)>
-//  ];
+  let builders = [
+    OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint, 
+                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l3_hint)>
+   ];
 
   let hasVerifier = 1;
 }
 
 def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
-  AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
+  AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
   ]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
@@ -791,15 +817,26 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   let arguments = (ins
     XeGPU_ValueType: $value,
-    XeGPU_TensorDesc: $TensorDesc,
+    XeGPU_TensorDesc_or_MemRef: $dest,
+    Variadic<Index>: $offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     XeGPU_MaskType: $mask,
+    OptionalAttr<I64Attr>: $chunk_size,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
+    Type getDestType() {
+      return getDest().getType();
+    }
+
+    Value getTensorDesc() {
+       return getDest();
+    }
+
     xegpu::TensorDescType getTensorDescType() {
-      return getTensorDesc().getType();
+      return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
 
     VectorType getValueType() {
@@ -811,8 +848,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     }
   }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
-            `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+  let assemblyFormat = [{
+    $value `,`
+    $dest `` 
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $mask 
+    prop-dict 
+    attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+  }];
+
+  let builders = [
+    OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
+                    "xegpu::CachePolicyAttr": $l1_hint, 
+                    "xegpu::CachePolicyAttr": $l2_hint, 
+                    "xegpu::CachePolicyAttr": $l3_hint)>
+   ];
 
   let hasVerifier = 1;
 }
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deeaa1f26b..4f3b3ed475afc 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -644,7 +644,7 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.isScattered())
+  if (tdescTy && !tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -659,6 +659,13 @@ LogicalResult PrefetchOp::verify() {
   return success();
 }
 
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+                       xegpu::CachePolicyAttr l1_hint,
+                       xegpu::CachePolicyAttr l2_hint,
+                       xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_LoadGatherOp
 //===----------------------------------------------------------------------===//
@@ -680,6 +687,15 @@ LogicalResult LoadGatherOp::verify() {
                                     [&]() { return emitOpError(); });
 }
 
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+                         Type valueType, Value source, Value mask,
+                         xegpu::CachePolicyAttr l1_hint,
+                         xegpu::CachePolicyAttr l2_hint,
+                         xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
+        mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_StoreScatterOp
 //===----------------------------------------------------------------------===//
@@ -701,6 +717,15 @@ LogicalResult StoreScatterOp::verify() {
                                     [&]() { return emitOpError(); });
 }
 
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+                           Value value, Value dest, Value mask,
+                           xegpu::CachePolicyAttr l1_hint,
+                           xegpu::CachePolicyAttr l2_hint,
+                           xegpu::CachePolicyAttr l3_hint) {
+  build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
+        IntegerAttr(), l1_hint, l2_hint, l3_hint);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_UpdateOffsetOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..a6208b455aa35 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 3523e3083c168..98836adaa57a3 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
-  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
-  gpu.return
-}
-
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,19 +330,8 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
-  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
-  %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  gpu.return
-}
-
-// CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
+// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
+gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
   %1 = arith.constant dense<1.0>: vector<32xf16>
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
@@ -658,6 +638,15 @@ gpu.func @prefetch(%src: ui64) {
 }
 
 
+// CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
+gpu.func @prefetch_offset(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+  xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+  gpu.return
+}
+
 // CHECK: gpu.func @create_update_tdesc(%[[arg0:.*]]: ui64) {
 gpu.func @create_update_tdesc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
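
The scattered prefetch now follows the same pattern: a base plus an optional
offset vector, with no descriptor involved. Verbatim from the test this
patch adds:

  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>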

>From 769bf19a3775cc8ceacbe9077371c4c712c9f493 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 02:33:06 +0000
Subject: [PATCH 08/16] add tests

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 31 +++---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 97 +++++++++++++++++--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 22 +++++
 mlir/test/Dialect/XeGPU/ops.mlir              | 40 ++++----
 4 files changed, 147 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index bf036d86d14bb..c6b192a9dda31 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -676,9 +676,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
   let hasVerifier = 1;
 }
 
-def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
-  AllElementTypesMatch<["value", "source"]>, MemoryEffects<[MemRead]>
-  ]> {
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   let summary = "load a set of scattered data points from memory.";
 
   let description = [{ It (aka. load) load data per each work-item. The output
@@ -721,8 +719,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   }];
 
   let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
-                       Variadic<Index>: $offsets,
-                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,  
+                       Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<I64Attr>: $chunk_size,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -760,11 +757,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   }];
 
   let assemblyFormat = [{
-    $source `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $source
+    (`[` $offsets^ `]`)? `,`
     $mask prop-dict 
-    attr-dict `:` qualified(type($source)) `,` type($mask) `->` type($value)
+    attr-dict `:` type(operands) `->` type($value)
   }];
+  
+  //    functional-type(operands, results)
+  // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
+  
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -776,9 +777,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
-  AllElementTypesMatch<["value", "dest"]>, MemoryEffects<[MemWrite]>
-  ]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store data to scattered memory locations.";
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
   typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -818,8 +817,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
   let arguments = (ins
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc_or_MemRef: $dest,
-    Variadic<Index>: $offsets,
-    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
+    Optional<XeGPU_OffsetType>: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr<I64Attr>: $chunk_size,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -850,12 +848,13 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
 
   let assemblyFormat = [{
     $value `,`
-    $dest `` 
-    custom<OptionalDynamicIndexList>($offsets, $const_offsets) `,`
+    $dest
+    (`[` $offsets^ `]`)? `,`
     $mask 
     prop-dict 
-    attr-dict `:` type($value) `,` qualified(type($dest)) `,` type($mask)
+    attr-dict `:`  type(operands)
   }];
+//    type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4f3b3ed475afc..7a32f1a45c762 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
   return success();
 }
 
+static LogicalResult
+isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
+                                 MemRefType memTy, int64_t chunkSize,
+                                 function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+  auto memShape = getShapeOf(memTy);
+
+  if (valueTy.getElementType() != memTy.getElementType())
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1) {
+    if (valueTy.getNumElements() != chunkSize)
+      return emitError() << "value elements must match chunk size " << chunkSize
+                         << " for SIMT code.";
+    return success();
+  }
+
+  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
+    return emitError() << "Mask should match value except the chunk size dim.";
+
+  return success();
+}
+
+static LogicalResult
+isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
+                                 int64_t chunkSize,
+                                 function_ref<InFlightDiagnostic()> emitError) {
+
+  if (!valueTy)
+    return emitError() << "Expecting a vector type result.";
+
+  auto maskShape = getShapeOf(maskTy);
+  auto valueShape = getShapeOf(valueTy);
+
+  // a valid shape for SIMT case
+  if (valueTy.getRank() == 1) {
+    if (valueTy.getNumElements() != chunkSize)
+      return emitError() << "value elements must match chunk size " << chunkSize
+                         << " for SIMT code.";
+    return success();
+  }
+
+  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+  if (chunkSize > 1)
+    expectedMaskShape.pop_back();
+  if (expectedMaskShape != maskShape)
+    return emitError() << "Mask should match value except the chunk size dim.";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -683,8 +743,18 @@ LogicalResult LoadGatherOp::verify() {
   if (!isReadHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+  auto srcTy = getSourceType();
+  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(srcTy);
+
+  if (memTy)
+    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+                                            [&]() { return emitOpError(); });
+  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
 }
 
 void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
@@ -692,8 +762,8 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                          xegpu::CachePolicyAttr l1_hint,
                          xegpu::CachePolicyAttr l2_hint,
                          xegpu::CachePolicyAttr l3_hint) {
-  build(builder, state, valueType, source, ValueRange(), DenseI64ArrayAttr(),
-        mask, IntegerAttr(), l1_hint, l2_hint, l3_hint);
+  build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+        l1_hint, l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
@@ -713,8 +783,19 @@ LogicalResult StoreScatterOp::verify() {
   if (!isWriteHintOrNone(getL3HintAttr()))
     return emitOpError("invalid l3_hint: ") << getL3HintAttr();
 
-  return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
-                                    [&]() { return emitOpError(); });
+  if (tdescTy)
+    return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+                                      [&]() { return emitOpError(); });
+
+  auto destTy = getDestType();
+  uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+  auto memTy = dyn_cast<MemRefType>(destTy);
+
+  if (memTy)
+    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
+                                            [&]() { return emitOpError(); });
+  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+                                          [&]() { return emitOpError(); });
 }
 
 void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
@@ -722,8 +803,8 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                            xegpu::CachePolicyAttr l1_hint,
                            xegpu::CachePolicyAttr l2_hint,
                            xegpu::CachePolicyAttr l3_hint) {
-  build(builder, state, value, dest, ValueRange(), DenseI64ArrayAttr(), mask,
-        IntegerAttr(), l1_hint, l2_hint, l3_hint);
+  build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+        l2_hint, l3_hint);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 0160bfee07bf2..af34add37f7ad 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,28 @@ func.func @load_gather_vc_3(%src: ui64) {
   return
 }
 
+// -----
+func.func @load_offset(%src: ui64) {
+  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<8xi1>
+  // expected-error at +1 {{Mask should match value except the chunk size dim}}
+  %2 = xegpu.load %src[%offsets], %mask
+        : ui64, vector<4xindex>, vector<8xi1>
+          -> vector<4x2xf32>
+  return
+}
+
+// -----
+func.func @store_offset(%src: ui64) {
+  %val = arith.constant dense<2.9>: vector<4x2xf16>
+  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<8xi1>
+  // expected-error at +1 {{Mask should match value except the chunk size dim}}
+  xegpu.store %val, %src[%offsets], %mask
+        : vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
+  return
+}
+
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index f8c558d614ee6..16f5356a69f24 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -130,15 +130,6 @@ gpu.func @prefetch_nd_offset_1(%src: memref<48x64xf16>, %x : index, %y : index)
   gpu.return
 }
 
-// CHECK: gpu.func @prefetch_nd_offset_1(%[[arg0:.*]]: memref<8x24x32x48x64xf16>) {
-gpu.func @prefetch_nd_offset_1(%src: memref<8x24x32x48x64xf16>) {
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  %1 = xegpu.create_nd_tdesc %src[0, 0, 0, 0, 0] : memref<8x24x32x48x64xf16> -> !xegpu.tensor_desc<1x2x4x8x16xf16>
-  // CHECK: xegpu.prefetch_nd %[[R0]][0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<1x2x4x8x16xf16>
-  xegpu.prefetch_nd %1[0, 0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<1x2x4x8x16xf16>
-  gpu.return
-}
-
 // CHECK: func @subgroup_load_nd(%[[arg0:.*]]: memref<8x16xf16>) {
 gpu.func @subgroup_load_nd(%src: memref<8x16xf16>) {
   // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
@@ -339,16 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
-// CHECK: func @subgroup_store_nd_3(%[[arg0:.*]]: memref<24x32xf16>) {
-gpu.func @subgroup_store_nd_3(%dst: memref<24x32xf16>) {
-  // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
-  %1 = arith.constant dense<1.0>: vector<32xf16>
-  // CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  %2 = xegpu.create_nd_tdesc %dst[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<32xf16>
-  // CHECK: xegpu.store_nd %[[C]], %[[R0]][0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}> : vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  xegpu.store_nd %1, %2[0] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>: vector<32xf16>, !xegpu.tensor_desc<32xf16>
-  gpu.return
-}
 
 // CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
@@ -541,6 +522,16 @@ gpu.func @subgroup_load_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_load_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_load_offset_1(%src: memref<?xf16>) {
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R1:.*]] = xegpu.load %arg0[%cst], %cst_0 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  %val = xegpu.load %src[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+      : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
+  gpu.return
+}
+
 // CHECK: gpu.func @subgroup_store(%[[arg0:.*]]: ui64) {
 gpu.func @subgroup_store(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -646,6 +637,17 @@ gpu.func @subgroup_store_4(%src: ui64) {
   gpu.return
 }
 
+// CHECK: gpu.func @subgroup_store_offset_1(%arg0: memref<?xf16>) {
+gpu.func @subgroup_store_offset_1(%dest: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4x2xf16>
+  %offset = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %mask = arith.constant dense<1>: vector<4xi1>
+  //CHECK: xegpu.store %[[R0:.*]], %arg0[%cst_0], %cst_1 <{chunk_size = 2 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  xegpu.store %val, %dest[%offset], %mask <{chunk_size=2, l1_hint = #xegpu.cache_hint<cached>}>
+      : vector<4x2xf16>, memref<?xf16>, vector<4xindex>, vector<4xi1>
+  gpu.return
+}
+
 // CHECK: gpu.func @prefetch(%[[arg0:.*]]: ui64) {
 gpu.func @prefetch(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

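For context on the two invalid tests added above: with a chunked gather or scatter, the mask must equal the value shape with the trailing chunk-size dimension dropped, so a vector<4x2xf32> value takes a vector<4xi1> mask. A conforming counterpart (a sketch, not part of the patch; the function name is hypothetical):

```mlir
func.func @load_gather_offset_ok(%src: ui64) {
  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  // chunk_size = 2 gives a vector<4x2xf32> value, so the mask is vector<4xi1>.
  %mask = arith.constant dense<1> : vector<4xi1>
  %v = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
        : ui64, vector<4xindex>, vector<4xi1> -> vector<4x2xf32>
  return
}
```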
>From 45537469eddccf41eabf689b19f550d8616442de Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 05:15:34 +0000
Subject: [PATCH 09/16] add invalid tests

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  7 ++--
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  6 +--
 mlir/test/Dialect/XeGPU/invalid.mlir          | 39 ++++++++++++++-----
 mlir/test/Dialect/XeGPU/ops.mlir              |  4 +-
 4 files changed, 39 insertions(+), 17 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c6b192a9dda31..312db1402f58f 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -661,11 +661,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 
   let assemblyFormat = [{
     $source 
-    (`,` $offsets^)? 
+    (`[` $offsets^ `]`)?
     prop-dict
-    attr-dict `:` type($source) (`,` type($offsets)^)?
+    attr-dict `:` type(operands) 
   }];
-
+    // type($source) (type($offsets)^)?
+    
   let builders = [
     OpBuilder<(ins "Value": $source,
                     "xegpu::CachePolicyAttr": $l1_hint, 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index c8f332184bd1b..cafbf8d5ffc5e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,7 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -546,7 +546,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -575,7 +575,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
-    if (!tdescTy.isScattered())
+    if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index af34add37f7ad..b56de88391803 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -385,27 +385,48 @@ func.func @load_gather_vc_3(%src: ui64) {
 }
 
 // -----
-func.func @load_offset(%src: ui64) {
+func.func @load_gather_offset_sg(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %mask = arith.constant dense<1>: vector<8xi1>
   // expected-error@+1 {{Mask should match value except the chunk size dim}}
   %2 = xegpu.load %src[%offsets], %mask
-        : ui64, vector<4xindex>, vector<8xi1>
-          -> vector<4x2xf32>
+        : memref<?xf16>, vector<4xindex>, vector<8xi1>
+          -> vector<4x2xf16>
+  return
+}
+
+// -----
+func.func @load_gather_offset_wi(%src: ui64) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{value elements must match chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf32>
+  return
+}
+
+// -----
+func.func @store_scatter_offset(%src: memref<?xf16>) {
+  %val = arith.constant dense<2.9>: vector<4xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{value elements must match chunk size}}
+  xegpu.store %val, %src[%offsets], %mask 
+        : vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
   return
 }
 
+
 // -----
-func.func @store_offset(%src: ui64) {
+func.func @load_gather_offset_wi(%src: ui64) {
   %val = arith.constant dense<2.9>: vector<4x2xf16>
-  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  %mask = arith.constant dense<1>: vector<8xi1>
-  // expected-error@+1 {{Mask should match value except the chunk size dim}}
-  xegpu.store %val, %src[%offsets], %mask
-        : vector<4x2xf16>, ui64, vector<4xindex>, vector<8xi1>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{value elements must match chunk size}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf32>
   return
 }
 
+
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {
   %0 = arith.constant dense<1>: vector<4xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index 16f5356a69f24..ea80601ef5574 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -664,8 +664,8 @@ gpu.func @prefetch(%src: ui64) {
 gpu.func @prefetch_offset(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
-  // CHECK: xegpu.prefetch %[[arg0]], %cst <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
-  xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
+  // CHECK: xegpu.prefetch %[[arg0]][%cst] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : ui64, vector<4xindex>
+  xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: ui64, vector<4xindex>
   gpu.return
 }
 

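The assembly change above moves the optional offsets into square brackets on the base operand and prints all operand types via type(operands). A before/after sketch of the prefetch spelling, mirroring the ops.mlir update:

```mlir
// Old form: offsets as a trailing comma-separated operand.
xegpu.prefetch %src, %0 <{l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<4xindex>
// New form: offsets subscript the base operand, matching load/store.
xegpu.prefetch %src[%0] <{l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<4xindex>
```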
>From 1249794952a1e67a4f0ff54b7e4fa39d0704b42e Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 05:28:50 +0000
Subject: [PATCH 10/16] small fixes

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td    | 6 ------
 mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 2 +-
 mlir/test/Dialect/XeGPU/invalid.mlir              | 2 +-
 3 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 312db1402f58f..82edd69f63694 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -665,7 +665,6 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     prop-dict
     attr-dict `:` type(operands) 
   }];
-    // type($source) (type($offsets)^)?
     
   let builders = [
     OpBuilder<(ins "Value": $source,
@@ -763,10 +762,6 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
     $mask prop-dict 
     attr-dict `:` type(operands) `->` type($value)
   }];
-  
-  //    functional-type(operands, results)
-  // type($source) (type($offsets)^ )? `,` type($mask) `->` type($value)
-  
 
   let builders = [
     OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
@@ -855,7 +850,6 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     prop-dict 
     attr-dict `:`  type(operands)
   }];
-//    type($value) `,` qualified(type($dest)) (type($offsets)^)? `,` type($mask)
 
   let builders = [
     OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index cafbf8d5ffc5e..29ee864b8f34f 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -502,7 +502,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     SmallVector<Type> convertedTdescTypes =
         getUnrolledTypes(tdescTy, *targetShape);
     SmallVector<Value> convertedTdescs = pack(
-        op.getSource(), convertedTdescTypes, *targetShape, loc, rewriter);
+        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);
 
     SmallVector<Type> convertedMaskTypes;
     SmallVector<Value> convertedMasks;
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index b56de88391803..b8e6a31d8d2f7 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
 }
 
 // -----
-func.func @store_scatter_offset(%src: memref<?xf16>) {
+func.func @store_scatter_offset_sg(%src: memref<?xf16>) {
   %val = arith.constant dense<2.9>: vector<4xf16>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>

>From 5cfb24b395ef3e83f4e316c9e35b9d2f8cb72dd1 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Fri, 25 Jul 2025 21:39:07 +0000
Subject: [PATCH 11/16] address comments

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td   | 7 +++----
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td | 2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp           | 1 -
 3 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 82edd69f63694..9da015b65a6af 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -16,7 +16,6 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
 include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
-include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
 
 // Base class for dialect operations. This operation inherits from the base
 // `Op` class in OpBase.td, and provides:
@@ -639,7 +638,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
                        Optional<XeGPU_OffsetType>: $offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -718,7 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 
   }];
 
-  let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
+  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
                        Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<I64Attr>: $chunk_size,
@@ -812,7 +811,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
 
   let arguments = (ins
     XeGPU_ValueType: $value,
-    XeGPU_TensorDesc_or_MemRef: $dest,
+    XeGPU_TensorDescOrMemRef: $dest,
     Optional<XeGPU_OffsetType>: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr<I64Attr>: $chunk_size,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 8e575e31255a7..fa59bf2e40c4d 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7a32f1a45c762..3a41b298e2aae 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -120,7 +120,6 @@ isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
 
   auto maskShape = getShapeOf(maskTy);
   auto valueShape = getShapeOf(valueTy);
-  auto memShape = getShapeOf(memTy);
 
   if (valueTy.getElementType() != memTy.getElementType())
     return emitError() << "Value should have the same element type as MemRef.";

>From 5940d191e865d06063a8ed97b804b90858bc78ba Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 17:33:18 +0000
Subject: [PATCH 12/16] address feedback

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 46 ++++++++++-
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  2 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 76 ++++++++-----------
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  3 +
 mlir/test/Dialect/XeGPU/invalid.mlir          |  2 +-
 mlir/test/Dialect/XeGPU/ops.mlir              |  2 -
 6 files changed, 80 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 9da015b65a6af..c864ce0c3d9cd 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,17 +628,28 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     As compared to prefetch_nd, which works on non-scattered TensorDesc,
     it works on scattered TensorDesc instead.
 
-    Example:
+    Example 1:
     ```mlir
       xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
                              l2_hint = #xegpu.cache_hint<cached>,
                              l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<16xf16>
     ```
+    
+    Example 2:
+    A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc". The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+    ```mlir
+      %a = memref.alloc() : memref<1024xf32>
+      %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+      xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<cached>,
+                             l3_hint = #xegpu.cache_hint<cached>}
+        : memref<1024xf32>, vector<4xindex>
+    ```
 
   }];
 
-  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                        Optional<XeGPU_OffsetType>: $offsets,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
@@ -706,6 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<16x8xf32>
   ```
+  
   Example 3 (SIMT mode):
   ```mlir
     %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -714,10 +726,22 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
             vector<16xi1> -> vector<8xf32>
   ```
+  
+  Example 4:
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "load with scattered TensorDesc". The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    %val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+                           l2_hint = #xegpu.cache_hint<cached>,
+                           l3_hint = #xegpu.cache_hint<cached>}
+      : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
+  ```
 
   }];
 
-  let arguments = (ins XeGPU_TensorDescOrMemRef: $source,
+  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
                        Optional<XeGPU_OffsetType>: $offsets,
                        XeGPU_MaskType: $mask,
                        OptionalAttr<I64Attr>: $chunk_size,
@@ -807,11 +831,25 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
                              l3_hint = #xegpu.cache_hint<write_through>}>
           : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
   ```
+
+  Example 4:
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "store with scattered TensorDesc". The dest operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+  ```mlir
+    %a = memref.alloc() : memref<1024xf32>
+    %val = arith.constant dense<0.0> : vector<16xf32>
+    %offsets = vector.step : vector<16xindex>
+    %mask = vector.constant_mask [16]: vector<16xi1>
+    xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
+                           l2_hint = #xegpu.cache_hint<cached>,
+                           l3_hint = #xegpu.cache_hint<cached>}
+      : vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
+  ```
+
   }];
 
   let arguments = (ins
     XeGPU_ValueType: $value,
-    XeGPU_TensorDescOrMemRef: $dest,
+    XeGPU_GatherScatterSourceType: $dest,
     Optional<XeGPU_OffsetType>: $offsets,
     XeGPU_MaskType: $mask,
     OptionalAttr<I64Attr>: $chunk_size,
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index fa59bf2e40c4d..b268cabb5d266 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,7 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_TensorDescOrMemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 3a41b298e2aae..7c8ee7408dfd1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -111,39 +111,7 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 }
 
 static LogicalResult
-isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
-                                 MemRefType memTy, int64_t chunkSize,
-                                 function_ref<InFlightDiagnostic()> emitError) {
-
-  if (!valueTy)
-    return emitError() << "Expecting a vector type result.";
-
-  auto maskShape = getShapeOf(maskTy);
-  auto valueShape = getShapeOf(valueTy);
-
-  if (valueTy.getElementType() != memTy.getElementType())
-    return emitError() << "Value should have the same element type as MemRef.";
-
-  // a valid shape for SIMT case
-  if (valueTy.getRank() == 1) {
-    if (valueTy.getNumElements() != chunkSize)
-      return emitError() << "value elements must match chunk size " << chunkSize
-                         << " for SIMT code.";
-    return success();
-  }
-
-  llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
-  if (chunkSize > 1)
-    expectedMaskShape.pop_back();
-  if (expectedMaskShape != maskShape)
-    return emitError() << "Mask should match value except the chunk size dim.";
-
-  return success();
-}
-
-static LogicalResult
-isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
-                                 int64_t chunkSize,
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize,
                                  function_ref<InFlightDiagnostic()> emitError) {
 
   if (!valueTy)
@@ -703,8 +671,14 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getSource()) > 1)
+      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+  }
 
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -733,6 +707,14 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getSource()) > 1)
+      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+  }
+
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -749,10 +731,10 @@ LogicalResult LoadGatherOp::verify() {
   uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
   auto memTy = dyn_cast<MemRefType>(srcTy);
 
-  if (memTy)
-    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
-                                            [&]() { return emitOpError(); });
-  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
 }
 
@@ -773,6 +755,14 @@ LogicalResult StoreScatterOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
+  if (tdescTy) {
+    if (!tdescTy.isScattered())
+      return emitOpError("Expects a scattered TensorDesc.\n");
+  } else {
+    if (getRankOf(getDest()) > 1)
+      return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t).");    
+  }
+  
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -790,10 +780,10 @@ LogicalResult StoreScatterOp::verify() {
   uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
   auto memTy = dyn_cast<MemRefType>(destTy);
 
-  if (memTy)
-    return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
-                                            [&]() { return emitOpError(); });
-  return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+    return emitError() << "Value should have the same element type as MemRef.";
+
+  return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
                                           [&]() { return emitOpError(); });
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index 29ee864b8f34f..d52f7f2ac274a 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -484,6 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
@@ -546,6 +547,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     Location loc = op.getLoc();
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
@@ -575,6 +577,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
+    // TODO: handle the unstructured source case (!tdescTy)
     if (!tdescTy || !tdescTy.isScattered())
       return failure();
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index b8e6a31d8d2f7..4cece4640634e 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -405,7 +405,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
 }
 
 // -----
-func.func @store_scatter_offset_sg(%src: memref<?xf16>) {
+func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
   %val = arith.constant dense<2.9>: vector<4xf16>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
diff --git a/mlir/test/Dialect/XeGPU/ops.mlir b/mlir/test/Dialect/XeGPU/ops.mlir
index ea80601ef5574..6be2371d4d7b2 100644
--- a/mlir/test/Dialect/XeGPU/ops.mlir
+++ b/mlir/test/Dialect/XeGPU/ops.mlir
@@ -330,7 +330,6 @@ gpu.func @subgroup_store_nd_2(%dst: memref<24x32xf16>, %x : index) {
   gpu.return
 }
 
-
 // CHECK: func @subgroup_store_nd_offset_1(%[[arg0:.*]]: memref<24x32xf16>) {
 gpu.func @subgroup_store_nd_offset_1(%dst: memref<24x32xf16>) {
   // CHECK: %[[C:.*]] = arith.constant dense<1.000000e+00> : vector<32xf16>
@@ -659,7 +658,6 @@ gpu.func @prefetch(%src: ui64) {
   gpu.return
 }
 
-
 // CHECK: gpu.func @prefetch_offset(%[[arg0:.*]]: ui64) {
 gpu.func @prefetch_offset(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>

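Besides the doc examples, this patch folds the memref and raw-pointer verifiers into one buffer check and verifies the element type up front. A load satisfying the new checks (a sketch mirroring the ops.mlir tests):

```mlir
func.func @gather_memref_ok(%src: memref<?xf16>) {
  %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
  %mask = arith.constant dense<1> : vector<4xi1>
  // 1D memref source with an f16 value matching the f16 memref; a
  // memref<4x4xf16> source or an f32 value vector would be rejected.
  %v = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
        : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
  return
}
```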
>From da7142aa4eb83340885e459edd71a4c865ba2aca Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 17:36:14 +0000
Subject: [PATCH 13/16] git-clang-format

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 7c8ee7408dfd1..45a4363bd11ba 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -111,7 +111,8 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 }
 
 static LogicalResult
-isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, int64_t chunkSize,
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+                                 int64_t chunkSize,
                                  function_ref<InFlightDiagnostic()> emitError) {
 
   if (!valueTy)
@@ -677,7 +678,8 @@ LogicalResult PrefetchOp::verify() {
       return emitOpError("Expects a scattered TensorDesc.\n");
   } else {
     if (getRankOf(getSource()) > 1)
-      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+      return emitOpError(
+          "Expecting the source is a 1D memref or pointer (uint64_t).");
   }
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -712,7 +714,8 @@ LogicalResult LoadGatherOp::verify() {
       return emitOpError("Expects a scattered TensorDesc.\n");
   } else {
     if (getRankOf(getSource()) > 1)
-      return emitOpError("Expecting the source is a 1D memref or pointer (uint64_t).");
+      return emitOpError(
+          "Expecting the source is a 1D memref or pointer (uint64_t).");
   }
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -731,7 +734,7 @@ LogicalResult LoadGatherOp::verify() {
   uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
   auto memTy = dyn_cast<MemRefType>(srcTy);
 
-  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
   return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
@@ -760,9 +763,10 @@ LogicalResult StoreScatterOp::verify() {
       return emitOpError("Expects a scattered TensorDesc.\n");
   } else {
     if (getRankOf(getDest()) > 1)
-      return emitOpError("Expecting the dest is a 1D memref or pointer (uint64_t).");    
+      return emitOpError(
+          "Expecting the dest is a 1D memref or pointer (uint64_t).");
   }
-  
+
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
 
@@ -780,7 +784,7 @@ LogicalResult StoreScatterOp::verify() {
   uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
   auto memTy = dyn_cast<MemRefType>(destTy);
 
-  if (memTy && (valueTy.getElementType() != memTy.getElementType()) )
+  if (memTy && (valueTy.getElementType() != memTy.getElementType()))
     return emitError() << "Value should have the same element type as MemRef.";
 
   return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,

>From 8b99ecc629be3c1dfb96c93ead49594dc30a47ef Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Tue, 29 Jul 2025 22:43:25 +0000
Subject: [PATCH 14/16] minor polish

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td |  3 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 42 ++++++++-----------
 2 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c864ce0c3d9cd..3e075de1651ab 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -863,7 +863,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     }
 
     Value getTensorDesc() {
-       return getDest();
+      assert(getTensorDescType() && "Expected dest to be a TensorDescType");
+      return getDest();
     }
 
     xegpu::TensorDescType getTensorDescType() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 45a4363bd11ba..1b114d41b6ca5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -673,14 +673,12 @@ LogicalResult CreateDescOp::verify() {
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
 
-  if (tdescTy) {
-    if (!tdescTy.isScattered())
-      return emitOpError("Expects a scattered TensorDesc.\n");
-  } else {
-    if (getRankOf(getSource()) > 1)
-      return emitOpError(
-          "Expecting the source is a 1D memref or pointer (uint64_t).");
-  }
+  if (tdescTy && !tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!tdescTy && getRankOf(getSource()) > 1)
+    return emitOpError(
+        "Expecting the source is a 1D memref or pointer (uint64_t).");
 
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -709,14 +707,12 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
-  if (tdescTy) {
-    if (!tdescTy.isScattered())
-      return emitOpError("Expects a scattered TensorDesc.\n");
-  } else {
-    if (getRankOf(getSource()) > 1)
-      return emitOpError(
-          "Expecting the source is a 1D memref or pointer (uint64_t).");
-  }
+  if (tdescTy && !tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!tdescTy && getRankOf(getSource()) > 1)
+    return emitOpError(
+        "Expecting the source is a 1D memref or pointer (uint64_t).");
 
   if (!isReadHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -758,14 +754,12 @@ LogicalResult StoreScatterOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
-  if (tdescTy) {
-    if (!tdescTy.isScattered())
-      return emitOpError("Expects a scattered TensorDesc.\n");
-  } else {
-    if (getRankOf(getDest()) > 1)
-      return emitOpError(
-          "Expecting the dest is a 1D memref or pointer (uint64_t).");
-  }
+  if (tdescTy && !tdescTy.isScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!tdescTy && getRankOf(getDest()) > 1)
+    return emitOpError(
+        "Expecting the dest is a 1D memref or pointer (uint64_t).");
 
   if (!isWriteHintOrNone(getL1HintAttr()))
     return emitOpError("invalid l1_hint: ") << getL1HintAttr();

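With this change getTensorDesc() returns a TypedValue<TensorDescType> that is null whenever the base operand is not a TensorDesc, so callers can branch on it without a separate type check. The two load spellings it distinguishes (a sketch; the tensor_desc shape and attributes are illustrative):

```mlir
// TensorDesc source: getTensorDesc() yields the source as a typed value.
%v0 = xegpu.load %tdesc, %mask
    : !xegpu.tensor_desc<4x2xf16, #xegpu.scatter_tdesc_attr<chunk_size=2>>, vector<4xi1> -> vector<4x2xf16>
// Memref-plus-offsets source: getTensorDesc() yields a null TypedValue.
%v1 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
    : memref<?xf16>, vector<4xindex>, vector<4xi1> -> vector<4x2xf16>
```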
>From 04306ca04fa6531f2cefe607555d64ba55fe83ef Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 30 Jul 2025 17:23:15 +0000
Subject: [PATCH 15/16] address comments

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 37 ++++++++++++++-----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        |  2 +-
 .../Dialect/XeGPU/Transforms/XeGPUUnroll.cpp  |  6 +--
 3 files changed, 31 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 3e075de1651ab..75b16a87e03c6 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -637,7 +637,10 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     ```
     
     Example 2:
-    A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc". The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+    A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+    It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
+    The source operand may also be a raw pointer (uint64_t).
+    Please refer to create_tdesc for the restrictions on memref.
     ```mlir
       %a = memref.alloc() : memref<1024xf32>
       %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -660,8 +663,11 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
       return getSource().getType();
     }
 
-    Value getTensorDesc() {
-       return getSource();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
@@ -728,7 +734,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
   ```
   
   Example 4:
-  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "load with scattered TensorDesc". The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
+  The source operand may also be a raw pointer (uint64_t). Please refer to create_tdesc
+  for the restrictions on memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %offsets = vector.step : vector<16xindex>
@@ -756,8 +765,11 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       return getSource().getType();
     }
 
-    Value getTensorDesc() {
-       return getSource();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
@@ -833,7 +845,10 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   ```
 
   Example 4:
-  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "store with scattered TensorDesc". The dest operand may also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on memref.
+  A variant accepts a memref as the base pointer and a vector of offsets instead of a scattered TensorDesc.
+  It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
+  The dest operand may also be a raw pointer (uint64_t).
+  Please refer to create_tdesc for the restrictions on memref.
   ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %val = arith.constant dense<0.0> : vector<16xf32>
@@ -862,9 +877,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       return getDest().getType();
     }
 
-    Value getTensorDesc() {
-      assert(getTensorDescType() && "Expected dest to be a TensorDescType");
-      return getDest();
+    TypedValue<xegpu::TensorDescType> getTensorDesc() {
+      if (auto tdescType = getTensorDescType()) {
+        return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
+      }
+      return TypedValue<xegpu::TensorDescType>();
     }
 
     xegpu::TensorDescType getTensorDescType() {
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1b114d41b6ca5..33450f3fa229e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -708,7 +708,7 @@ LogicalResult LoadGatherOp::verify() {
   auto valueTy = getValueType();
 
   if (tdescTy && !tdescTy.isScattered())
-    return emitOpError("Expects a scattered TensorDesc.\n");
+    return emitOpError("Expects a scattered TensorDesc.");
 
   if (!tdescTy && getRankOf(getSource()) > 1)
     return emitOpError(
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index d52f7f2ac274a..9f0c074a1489d 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -485,7 +485,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -548,7 +548,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -578,7 +578,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
     xegpu::TensorDescType tdescTy = op.getTensorDescType();
 
     // TODO: handle the unstructured source case (!tdescTy)
-    if (!tdescTy || !tdescTy.isScattered())
+    if (!tdescTy || op.getOffsets())
       return failure();
 
     std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);

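The guard change above also narrows when the unroll patterns fire: they now require a tensor_desc source and bail out when explicit offsets are present. An op the pass now deliberately skips (sketch):

```mlir
// No tensor_desc source here, so UnrollPrefetchOp returns failure() and the
// op is left untouched for a later lowering to handle.
xegpu.prefetch %src[%offsets] : memref<?xf16>, vector<4xindex>
```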
>From bbd6530ee6c0789af571578530dabe3e0cce7915 Mon Sep 17 00:00:00 2001
From: Jianhui Li <jian.hui.li at intel.com>
Date: Wed, 30 Jul 2025 17:53:36 +0000
Subject: [PATCH 16/16] add more invalid tests

---
 mlir/test/Dialect/XeGPU/invalid.mlir | 33 ++++++++++++++++++++++++----
 1 file changed, 29 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 4cece4640634e..dff3ffab39ecf 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -384,6 +384,14 @@ func.func @load_gather_vc_3(%src: ui64) {
   return
 }
 
+// -----
+func.func @prefetch_offset_wi_1(%src: memref<4x4xf32>) {
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  xegpu.prefetch %src[%offsets]: memref<4x4xf32>, vector<1xindex>
+  return
+}
+
 // -----
 func.func @load_gather_offset_sg(%src: memref<?xf16>) {
   %offsets = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
@@ -405,7 +413,7 @@ func.func @load_gather_offset_wi(%src: ui64) {
 }
 
 // -----
-func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
+func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
   %val = arith.constant dense<2.9>: vector<4xf16>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   %mask = arith.constant dense<1>: vector<1xi1>
@@ -415,17 +423,34 @@ func.func @store_scatter_offset_wi(%src: memref<?xf16>) {
   return
 }
 
+// -----
+func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
+  %val = arith.constant dense<2.9>: vector<4xf16>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  %mask = arith.constant dense<1>: vector<1xi1>
+  // expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
+  xegpu.store %val, %src[%offsets], %mask 
+        : vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
+  return
+}
 
 // -----
-func.func @load_gather_offset_wi(%src: ui64) {
-  %val = arith.constant dense<2.9>: vector<4x2xf16>
+func.func @load_gather_offset_wi_2(%src: ui64) {
   %mask = arith.constant dense<1>: vector<1xi1>
   %offsets = arith.constant dense<[0]> : vector<1xindex>
   // expected-error@+1 {{value elements must match chunk size}}
-  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf32>
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : ui64,  vector<1xindex>, vector<1xi1> -> vector<4xf16>
   return
 }
 
+// -----
+func.func @load_gather_offset_wi_1(%src: memref<4x4xf32>) {
+  %mask = arith.constant dense<1>: vector<1xi1>
+  %offsets = arith.constant dense<[0]> : vector<1xindex>
+  // expected-error@+1 {{Expecting the source is a 1D memref or pointer}}
+  %2 = xegpu.load %src[%offsets], %mask <{chunk_size = 2}> : memref<4x4xf32>,  vector<1xindex>, vector<1xi1> -> vector<2xf32>
+  return
+}
 
 // -----
 func.func @store_scatter_vc_1(%src: memref<24x32xf32>) {


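The *_wi tests added here exercise the SIMT rule: a rank-1 value means a single work-item owns exactly one chunk, so the element count must equal chunk_size. A conforming lane-level load for contrast (a sketch; the invalid cases above return 4 elements for chunk_size = 2):

```mlir
func.func @load_gather_offset_wi_ok(%src: ui64) {
  %mask = arith.constant dense<1> : vector<1xi1>
  %offsets = arith.constant dense<[0]> : vector<1xindex>
  // vector<2xf32> has exactly chunk_size = 2 elements, as the verifier requires.
  %v = xegpu.load %src[%offsets], %mask <{chunk_size = 2}>
        : ui64, vector<1xindex>, vector<1xi1> -> vector<2xf32>
  return
}
```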

More information about the Mlir-commits mailing list