[Mlir-commits] [mlir] [MLIR][XeGPU] Add XeGPU scattered ops (PR #86594)

Chao Chen llvmlistbot at llvm.org
Tue Mar 26 07:50:15 PDT 2024


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/86594

>From 0f86b93a11d05339ac6ebd44435f513fa0e519e0 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Mon, 25 Mar 2024 22:27:58 +0000
Subject: [PATCH 1/2] Add XeGPU scattered ops

---
 mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h    |   1 +
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       |  18 +-
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 449 +++++++++++++++---
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       |  21 +-
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  21 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 241 +++++++++-
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         |  62 +++
 7 files changed, 723 insertions(+), 90 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 87aabdc015fea5..eca9255ff3974b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/TypeUtilities.h"
 #include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index cd38549f1ccf43..5a05462b3579de 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -22,14 +22,16 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
   let parameters = (ins
     OptionalParameter<"MemoryScopeAttr">: $memory_scope,
     OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check
+    OptionalParameter<"BoolAttr", "true">: $boundary_check,
+    OptionalParameter<"BoolAttr", "false">: $scattered
   );
 
   let builders = [
     AttrBuilder<(ins
       CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
       CArg<"int", "1">:$array_length,
-      CArg<"bool", "true">: $boundary_check
+      CArg<"bool", "true">: $boundary_check,
+      CArg<"bool", "false">: $scattered
     )>
   ];
 
@@ -41,14 +43,14 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
 //===----------------------------------------------------------------------===//
 def XeGPU_MemoryScopeGlobal: I32EnumAttrCase<"Global", 0, "global">;
 def XeGPU_MemoryScopeShared: I32EnumAttrCase<"SLM", 1, "slm">;
-def XeGPU_MemoryScope: I32EnumAttr<"MemoryScope", 
-      "The address space of the memory the tensor descritor is created for", 
+def XeGPU_MemoryScope: I32EnumAttr<"MemoryScope",
+      "The address space of the memory the tensor descritor is created for",
       [XeGPU_MemoryScopeGlobal, XeGPU_MemoryScopeShared]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
 }
 
-def XeGPU_MemoryScopeAttr: 
+def XeGPU_MemoryScopeAttr:
   EnumAttr<XeGPU_Dialect, XeGPU_MemoryScope, "memory_scope"> {
     let assemblyFormat = "$value";
 }
@@ -63,15 +65,15 @@ def XeGPU_CachePolicyInvalid:       I32EnumAttrCase<"READ_INVALIDATE", 3, "read_
 def XeGPU_CachePolicyWriteBack:     I32EnumAttrCase<"WRITE_BACK", 4, "write_back">;            // valid for write only
 def XeGPU_CachePolicyWriteThrough:  I32EnumAttrCase<"WRITE_THROUGH", 5, "write_through">;      // valid for write only
 
-def XeGPU_CachePolicyEnums : I32EnumAttr<"CachePolicy", "Cache policy", 
-  [XeGPU_CachePolicyCached, XeGPU_CachePolicyUncached, 
+def XeGPU_CachePolicyEnums : I32EnumAttr<"CachePolicy", "Cache policy",
+  [XeGPU_CachePolicyCached, XeGPU_CachePolicyUncached,
    XeGPU_CachePolicyStreaming, XeGPU_CachePolicyInvalid,
    XeGPU_CachePolicyWriteBack, XeGPU_CachePolicyWriteThrough]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
 }
 
-def XeGPU_CacheHintAttr 
+def XeGPU_CacheHintAttr
   : EnumAttr<XeGPU_Dialect, XeGPU_CachePolicyEnums, "cache_hint"> {
     let assemblyFormat = "`<` $value `>`";
 }
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 93c56ad05b432c..0380ff83581517 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -46,36 +46,35 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
 }
 
 
-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, 
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
                         AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
 
   let summary = "Create nd-tensor descriptor operation";
   let description = [{
     The "create_nd_tdesc" operation creates a TensorDescType which represents
     a sub-view of a 2D memory region (It can be extended to support n-D memory
-    region if needed in future). Elements in the subview continuous in each 
-    dimention. It encodes the following important information for supporting 
+    region if needed in future). Elements in the subview are contiguous in each
+    dimension. It encodes the following important information for supporting
     Intel hardware features:
 
-    * source: an object representing (starting address/pointer of) a 2D memory region. 
+    * source: an object representing (starting address/pointer of) a 2D memory region.
         It can be either a 2D memref object, or simply a pointer represented by uint64_t type.
-        for the later case, the shape and layout information of the 2D memory region should 
-        be explicitly passed via `dynamic_shape` and `dynamic_strides` parameters.
-    * offsets: two index values represents offsets from the "source" at the each dimension 
+        for the latter case, the shape and layout information of the 2D memory region should
+        be explicitly passed via `shape` and `strides` parameters.
+    * offsets: two index values representing the offsets from the "source" in each dimension
         at which the subview of the target memory will be created. It is encoded via two
-        variables, including "dynamic_offsets" and "static_offsets", such that it can
-        accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4])).
-    * shape: the shape information of the memory region pointed by the "source".  It is 
-        typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>. 
-        But if "source" is simply a pointer represented as uint64_t type, or a memref 
-        type without shape information e.g., memref<?x?xf16>, the shape information has 
-        to be explicitly passed via the "dynamic_shape" argument. Currently "dynamic_shape" 
-        only accepts operands(e.g., [%c4096, %c4096]), not attributes(e.g., [4096, 4096]).
-    * strides: the strides of the memory region pointed by the "source". Similar to shape, 
-        it is typically encoded via the MemRefType of the source too. But if "source" is 
-        simply a pointer represented as uint64_t type, or a memref type without shape 
-        information e.g., memref<?x?xf16>, the strides information has to be explicitly 
-        passed via the "dynamic_strides" argument. And it currently only accepts operands two.
+        variables, including "offsets" and "const_offsets", such that it can
+        accept various forms, such as, operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
+    * shape: the shape information of the memory region pointed by the "source".  It is
+        typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
+        But if "source" is simply a pointer represented as uint64_t type, or a memref
+        type without shape information e.g., memref<?x?xf16>, the shape information has
+        to be explicitly passed via the "shape" and "const_shape" arguments.
+    * strides: the strides of the memory region pointed by the "source". Similar to shape,
+        it is typically encoded via the MemRefType of the source too. But if "source" is
+        simply a pointer represented as uint64_t type, or a memref type without shape
+        information e.g., memref<?x?xf16>, the strides information has to be explicitly
+        passed via the "strides" and "const_strides" arguments.
 
     Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
     %0 = memref.alloc() : memref<1024x1024xf32>
@@ -96,10 +95,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
   }];
 
-  let arguments = (ins 
-    XeGPU_BaseAddrType: $source, 
-    Variadic<Index>: $offsets, 
-    Variadic<Index>: $shape, 
+  let arguments = (ins
+    XeGPU_BaseAddrType: $source,
+    Variadic<Index>: $offsets,
+    Variadic<Index>: $shape,
     Variadic<Index>: $strides,
     DenseI64ArrayAttr: $const_offsets,
     OptionalAttr<DenseI64ArrayAttr>: $const_shape,
@@ -118,12 +117,12 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
   let hasVerifier = 1;
 
   let builders = [
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source, 
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets)>,
 
-    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source, 
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
                    "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   "llvm::ArrayRef<OpFoldResult>": $shape, 
+                   "llvm::ArrayRef<OpFoldResult>": $shape,
                    "llvm::ArrayRef<OpFoldResult>": $strides)>
   ];
 
@@ -158,41 +157,41 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }
 
     /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled, 
+    /// If source is IntegerType or `const_shape` is filled,
     /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of 
+    /// and `const_shape` will be used to represent the shape of
     /// source operand. They overide static shape from source memref type.
     ArrayRef<int64_t> getStaticSizes() {
       auto attr = getConstShapeAttr();
       if (getSourceType().isa<IntegerType>() || attr)
         return attr;
-      
+
       auto memrefType = getSourceType().dyn_cast<MemRefType>();
       assert(memrefType && "Incorrect use of getStaticSizes");
       return memrefType.getShape();
     }
 
     /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it 
+    /// If source is IntegerType or `const_strides` is filled, it
     /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of 
+    /// and `const_strides` will be used to represent the strides of
     /// source operand. They overide static strides from source memref type.
     ArrayRef<int64_t> getStaticStrides() {
       auto attr = getConstStridesAttr();
       if (getSourceType().isa<IntegerType>() || attr)
         return attr;
-      
+
       auto memrefType = getSourceType().dyn_cast<MemRefType>();
       assert(memrefType && "Incorrect use of getStaticStrides");
       auto [strides, offset] = getStridesAndOffset(memrefType);
-      // reuse the storage of ConstStridesAttr since strides from 
+      // reuse the storage of ConstStridesAttr since strides from
       // memref is not persistant
       setConstStrides(strides);
       attr = getConstStridesAttr();
       return attr;
     }
 
-    /// Return the expected rank of each of the`static_offsets`, 
+    /// Return the expected rank of each of the`static_offsets`,
     /// `static_shape` and `static_strides` attributes.
     std::array<unsigned, 3> getArrayAttrMaxRanks() {
       unsigned rank;
@@ -203,8 +202,8 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
       }
       return {rank, rank, rank};
     }
-    
-    /// Return the number of leading operands before the `offsets`, 
+
+    /// Return the number of leading operands before the `offsets`,
     /// `shape` and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
 
@@ -213,15 +212,15 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
 }
 
 def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
-  let summary = "prefetches a nD block to cache";
+  let summary = "prefetches a n-D block to cache";
   let description = [{
-    It issues an instruction to prefetch the data from memory to each 
-    level of the cache based on their cache policy.
+    It issues an instruction to prefetch a block of data from continuous
+    memory regions to each level of the cache based on their cache policy.
 
     Example:
     ```
-      xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>, 
-                                l2_hint = #xegpu.cache_hint<cached>, 
+      xegpu.prefetch_nd %tdesc {l1_hint = #xegpu.cache_hint<cached>,
+                                l2_hint = #xegpu.cache_hint<cached>,
                                 l3_hint = #xegpu.cache_hint<cached>}
         : !xegpu.tensor_desc<8x16xf16>
     ```
@@ -232,34 +231,41 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
-                       
-  let extraClassDeclaration = extraBaseClassDeclaration;
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+  }];
 
   let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+
+  let hasVerifier = 1;
 }
 
 
-def XeGPU_LoadNdOp : XeGPU_Op<"load_nd"> {
-  let summary = "loads a n-D block from memory (represented by TensorDesc)" 
+def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [AllElementTypesMatch<["value", "TensorDesc"]>,
+                                         AllElementCountsMatch<["value", "TensorDesc"]>]> {
+  let summary = "loads a n-D block from memory (represented by TensorDesc)"
                 "to registers (represented by vector)";
   let description = [{
-    LoadNdOp essentially mimics the hardware block read instruction to read 
-    a block of data from memory to register. It takes a set of optional cache 
-    hints for each level of cache, L1, L2 and L3. If hardware does not have a 
+    LoadNdOp essentially mimics the hardware block read instruction to read
+    a block of data from memory to register. It takes a set of optional cache
+    hints for each level of cache, L1, L2 and L3. If hardware does not have a
     correspoding cache, Corresponding cache hint attribute will be masked.
-    vnni transform is an hardware feature for Intel GPU, which is used to 
-    do data packing during the load for B operand of matrix operation, if 
-    the bit width of the data type is less then 32 bits, e.g., fp16. And 
+    vnni transform is a hardware feature for Intel GPU, which is used to
+    do data packing during the load for B operand of matrix operation, if
+    the bit width of the data type is less than 32 bits, e.g., fp16. And
     transpose is another Intel hardware feature, which will do transpose
-    operation when loading the data if the bit width of the data type is 
-    fp32 or fp64. It implies that vnni and transpose cannot exit at the 
+    operation when loading the data if the bit width of the data type is
+    fp32 or fp64. It implies that vnni and transpose cannot exist at the
     same time.
 
     Example:
     ```
       xegpu.load_nd %1 {transpose = [1, 0],
-                        l1_hint = #xegpu.cache_hint<cached>, 
-                        l2_hint = #xegpu.cache_hint<uncached>, 
+                        l1_hint = #xegpu.cache_hint<cached>,
+                        l2_hint = #xegpu.cache_hint<uncached>,
                         l3_hint = #xegpu.cache_hint<streaming>}
               : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
     ```
@@ -290,20 +296,21 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd"> {
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", []> {
+def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [AllShapesMatch<["value", "TensorDesc"]>,
+                                       AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "stores a n-D block register region back to memory, currently only supports 2D";
 
   let description = [{
     StoreNdOp essentially mimics the hardware block write instruction io
-    write a block of data from register into the memory region as described 
-    by the TensorDesc. It takes a set of optional cache hints for each level 
-    of cache, L1, L2 and L3. If hardware does not have a correspoding cache, 
+    write a block of data from register into the memory region as described
+    by the TensorDesc. It takes a set of optional cache hints for each level
+    of cache, L1, L2 and L3. If hardware does not have a corresponding cache,
     Corresponding cache hint attribute will be masked.
 
     Example:
     ```
       xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
-                             l2_hint = #xegpu.cache_hint<write_back>, 
+                             l2_hint = #xegpu.cache_hint<write_back>,
                              l3_hint = #xegpu.cache_hint<write_through>}
                              : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
@@ -317,11 +324,327 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", []> {
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
 
-  let extraClassDeclaration = extraBaseClassDeclaration;
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    VectorType getValueType() {
+      return llvm::dyn_cast<VectorType>(getValue().getType());
+    }
+
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+  }];
 
-  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict 
+  let assemblyFormat = [{$value `,` $TensorDesc prop-dict attr-dict
                         `:` type($value) `,` qualified(type($TensorDesc))}];
   let hasVerifier = 1;
 }
 
+def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
+                [AllTypesMatch<["TensorDesc", "result"]>]> {
+  let summary = "It updates the offsets for the TensorDesc.";
+  let description = [{The op updates the offset of the given TensorDesc.
+    The offsets are relative to the current position, expressed in number
+    of elements. The result is a TensorDesc of the same type as the input.
+
+  example:
+  ```
+    %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
+  ```
+  }];
+
+  let arguments = (ins
+    XeGPU_TensorDesc: $TensorDesc,
+    Variadic<Index>: $offsets,
+    DenseI64ArrayAttr: $const_offsets);
+
+  let results = (outs XeGPU_TensorDesc: $result);
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      Builder b(getContext());
+      return getMixedValues(getConstOffsets(), getOffsets(), b);
+    }
+
+    size_t getNumOffsets() {
+      return getMixedOffsets().size();
+    }
+
+    OpFoldResult getOffset(unsigned idx) {
+      assert(idx < getNumOffsets() && "Invalid out of bound access.");
+      return getMixedOffsets()[idx];
+    }
+  }];
+
+  let assemblyFormat = [{
+    $TensorDesc `,`
+    custom<DynamicIndexList>($offsets, $const_offsets)
+    attr-dict `:` qualified(type($result))
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
+  let summary = "create scattered tensor descriptors (TensorDesc).";
+  let description = [{
+    "create_tdesc" is similar to "create_nd_tdesc" in terms that it creates
+    a Tensor Descriptor (TensorDescType) for a memory region. While "create_nd_tdesc"
+    is for creating contiguous subviews, "create_tdesc" is for creating non-contiguous
+    (scattered) subviews, allowing each work-item in a subgroup to specify its own offset.
+    It accepts the following parameters:
+
+    * source: a 1D memref or pointer (uint64_t) representing the flattened memory object.
+    * offsets: an array containing the offset of each access point. Its size
+      is fixed to the hardware-supported subgroup size, e.g., 16 on PVC,
+      implying each element in the array corresponds to a work-item (SIMT lane)
+      in the subgroup.
+    * chunk_size: [optional attribute] indicates the number of contiguous
+      elements accessed for each offset; the default is 1.
+
+    Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
+    %a = memref.alloc() : memref<1024xf32>
+    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
+
+    Example 2. It assumes subgroup size is 4, and each work-item accesses 8 elements.
+               It accesses 32 data elements in total: a[0:7], a[16:23], a[32:39], a[64:71]
+    %0 = memref.alloc() : memref<1024xf32>
+    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+  }];
+
+  let arguments = (ins XeGPU_BaseAddrType: $source,
+                       Variadic<Index>: $offsets,
+                       DenseI64ArrayAttr: $const_offsets,
+                       DefaultValuedAttr<I64Attr, "1">: $chunk_size);
+  let results = (outs XeGPU_TensorDesc:$TensorDesc);
+
+  let builders = [
+    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets,
+                   CArg<"uint32_t", "1"> : $chunk_size)>,
+  ];
+
+  let assemblyFormat = [{
+    $source
+    custom<DynamicIndexList>($offsets, $const_offsets)
+    attr-dict `:`  type($source) `->` qualified(type($TensorDesc))
+  }];
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      Builder b(getContext());
+      return getMixedValues(getConstOffsets(), getOffsets(), b);
+    }
+
+    size_t getNumOffsets() {
+      return getMixedOffsets().size();
+    }
+
+    mlir::Value getViewSource() { return getSource(); }
+
+    OpFoldResult getOffset(unsigned idx) {
+      assert(idx < getNumOffsets() && "Invalid out of bound access.");
+      return getMixedOffsets()[idx];
+    }
+  }];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
+  let summary = "prefetches a set of scattered data points to cache";
+
+  let description = [{
+    It issues instructions to prefetch a set of scattered data points
+    from memory to each level of the cache based on their cache policy.
+    As compared to prefetch_nd, which works on non-scattered TensorDesc,
+    it works on scattered TensorDesc instead.
+
+    Example:
+    ```
+      xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<cached>,
+                             l3_hint = #xegpu.cache_hint<cached>}
+        : !xegpu.tensor_desc<16xf16>
+    ```
+
+  }];
+
+  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+  }];
+
+  let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]>,
+                                    AllElementTypesMatch<["value", "TensorDesc"]>,
+                                   AllElementCountsMatch<["value", "TensorDesc"]>]> {
+  let summary = "load a set of scattered data points from memory.";
+
+  let description = [{ It (aka. load) loads data for each work-item. The output
+    describes the data being loaded at the subgroup level, so its size is
+    consistent with the number of work-items in a subgroup. When `chunk_size_per_lane`
+    attribute is larger than 1 in TensorDesc, the output vector will be 2D vector,
+    with dim-1 corresponding to the chunk size.
+
+    The mask operand masks out memory access so that it is safe to pass out-of-boundary
+    addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
+
+  Example:
+  ```
+    %2 = xegpu.load %1, %0 {transpose = [1, 0],
+                            l1_hint = #xegpu.cache_hint<cached>,
+                            l2_hint = #xegpu.cache_hint<uncached>,
+                            l3_hint = #xegpu.cache_hint<uncached>}
+          : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
+            -> vector<16xf32>
+  ```
+
+  }];
+
+  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       XeGPU_MaskType: $mask,
+                       OptionalAttr<DenseI64ArrayAttr>: $transpose,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+  let results = (outs XeGPU_ValueType: $value);
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+
+    mlir::Type getElementType() {
+      auto type = getValue().getType();
+      return getElementTypeOrSelf(type);
+    }
+
+    Type getValueType() {
+      return getValue().getType();
+    }
+
+    Type getMaskType() {
+      return getMask().getType();
+    }
+
+  }];
+
+  let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
+      `:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDesc"]>,
+                                        AllElementTypesMatch<["value", "TensorDesc"]>]> {
+  let summary = "store data to scattered memory locations.";
+  let description = [{ It (aka. store) stores data to scattered memory locations.
+  It has similar semantic to `load_gather`.
+
+  Example:
+  ```
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
+  ```
+  }];
+
+  let arguments = (ins
+    XeGPU_ValueType: $value,
+    XeGPU_TensorDesc: $TensorDesc,
+    XeGPU_MaskType: $mask,
+    OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
+    OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
+    OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+
+    Type getValueType() {
+      return getValue().getType();
+    }
+
+    Type getMaskType() {
+      return getMask().getType();
+    }
+  }];
+
+  let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
+            `:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
+
+  let hasVerifier = 1;
+}
+
+def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
+          [AllTypesMatch<["TensorDesc", "result"]>]> {
+  let summary = "It updates the offsets for the given tensor descriptor";
+
+  let description = [{It behaves similarly to `update_nd_offset` in terms that
+    it updates offset of a TensorDesc, and the offsets are relative offset to
+    the current position in the number of elements. However, `update_nd_offset`
+    is to update the start point of a 2D block, so its offset contains two
+    elements representing the shift in each dimension. `update_offset` is to
+    update the offset per work-item, so its offsets contains values representing
+    shifts for each work-item.
+
+    Example:
+    ```
+      %2 = xegpu.update_offset %1, [32, 32, 32, 32]
+            : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+    ```
+  }];
+
+  let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
+                       Variadic<Index>: $offsets,
+                       DenseI64ArrayAttr: $const_offsets);
+  let results = (outs XeGPU_TensorDesc: $result);
+
+  let extraClassDeclaration = extraBaseClassDeclaration # [{
+    xegpu::TensorDescType getTensorDescType() {
+      return getTensorDesc().getType();
+    }
+
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      Builder b(getContext());
+      return getMixedValues(getConstOffsets(), getOffsets(), b);
+    }
+
+    size_t getNumOffsets() {
+      return getMixedOffsets().size();
+    }
+
+    OpFoldResult getOffset(unsigned idx) {
+      assert(idx < getNumOffsets() && "Invalid out of bound access.");
+      return getMixedOffsets()[idx];
+    }
+  }];
+
+  let assemblyFormat = [{
+    $TensorDesc `,`
+    custom<DynamicIndexList>($offsets, $const_offsets)
+    attr-dict `:` qualified(type($TensorDesc))
+  }];
+}
+
 #endif // MLIR_DIALECT_XEGPU_IR_XEGPUOPS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 19ac1693712dd8..0c62e513bee4f3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     element-type ::= float-type | integer-type | index-type
     dim-list := (static-dim-list `x`)?
     static-dim-list ::= decimal-literal `x` decimal-literal
-    attr-list = (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)?
+    attr-list = (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
     ```
 
     Examples:
@@ -84,6 +84,17 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
                         "mlir::Type": $elementType,
                         OptionalParameter<"mlir::Attribute">: $encoding);
 
+  let builders = [
+    TypeBuilder<(ins
+      "llvm::ArrayRef<int64_t>": $shape, 
+      "mlir::Type": $elementType,
+      CArg<"bool", "false">: $scattered,
+      CArg<"int", "1">: $array_length,
+      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
+      CArg<"bool", "true">: $boundary_check
+    )>
+  ];
+
   let extraClassDeclaration = [{
     using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -126,6 +137,14 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       // return default value
       return true;
     }
+
+    bool getScattered() {
+      auto attr = getEncodingAsTensorDescAttr();
+      if (attr && attr.getScattered())
+        return attr.getScattered().getValue();
+      // return default value
+      return false;
+    }
   }];
 
   let hasCustomAssemblyFormat = true;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0b3f4b9c9dbeae..858cda32013eae 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -32,6 +32,17 @@ void XeGPUDialect::initialize() {
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescAttr
 //===----------------------------------------------------------------------===//
+// Wraps each plain C++ value in its attribute form and builds the attribute.
+TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
+                                   xegpu::MemoryScope memory_scope,
+                                   int array_length, bool boundary_check,
+                                   bool scattered) {
+  return Base::get(
+      context, MemoryScopeAttr::get(context, memory_scope),
+      IntegerAttr::get(IntegerType::get(context, 64), array_length),
+      BoolAttr::get(context, boundary_check),
+      BoolAttr::get(context, scattered));
+}
 
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
@@ -96,6 +107,16 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
   printer << ">";
 }
 
+// Builds the type with the options packed into a TensorDescAttr encoding.
+TensorDescType
+TensorDescType::get(mlir::MLIRContext *context, llvm::ArrayRef<int64_t> shape,
+                    mlir::Type elementType, bool scattered, int array_length,
+                    MemoryScope memory_scope, bool boundary_check) {
+  return Base::get(context, shape, elementType,
+                   TensorDescAttr::get(context, memory_scope, array_length,
+                                       boundary_check, scattered));
+}
+
 } // namespace xegpu
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 02106f221f3233..4efa46642aa78f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -9,6 +9,9 @@
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "xegpu"
 
@@ -38,6 +41,38 @@ static std::string makeString(T array, bool breakline = false) {
   return buf;
 }
 
+static std::vector<int64_t> getShapeOf(Type type) {
+  std::vector<int64_t> shape;
+  if (auto ty = llvm::dyn_cast<ShapedType>(type))
+    shape = ty.getShape().vec();
+  else
+    shape.push_back(1); // a non-shaped type is reported as the shape {1}
+  return shape;
+}
+
+// Rank of `val` when its type is a ShapedType; any other type counts as 0.
+static int64_t getRankOf(Value val) {
+  if (auto shapedTy = llvm::dyn_cast<ShapedType>(val.getType()))
+    return shapedTy.getRank();
+  return 0;
+}
+
+// A missing hint is accepted; otherwise it must be one of the policies below.
+static bool isReadHintOrNone(const CachePolicyAttr &attr) {
+  return !attr || attr.getValue() == CachePolicy::CACHED ||
+         attr.getValue() == CachePolicy::UNCACHED ||
+         attr.getValue() == CachePolicy::STREAMING ||
+         attr.getValue() == CachePolicy::READ_INVALIDATE;
+}
+
+// A missing hint is accepted; otherwise it must be one of the policies below.
+static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
+  return !attr || attr.getValue() == CachePolicy::CACHED ||
+         attr.getValue() == CachePolicy::UNCACHED ||
+         attr.getValue() == CachePolicy::WRITE_BACK ||
+         attr.getValue() == CachePolicy::WRITE_THROUGH;
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
@@ -114,6 +149,29 @@ LogicalResult CreateNdDescOp::verify() {
     return emitOpError("TensorDesc should have the same element "
                        "type with the source if it is a memref.\n");
 
+  if (getType().getScattered())
+    return emitOpError("Expects a non-scattered TensorDesc.\n");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchNdOp
+//===----------------------------------------------------------------------===//
+// Checks the descriptor kind and that every cache hint is valid for a read.
+LogicalResult PrefetchNdOp::verify() {
+  if (getTensorDescType().getScattered())
+    return emitOpError("Expects a non-scattered TensorDesc.\n");
+
+  if (!isReadHintOrNone(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isReadHintOrNone(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isReadHintOrNone(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
   return success();
 }
 
@@ -125,18 +183,22 @@ LogicalResult LoadNdOp::verify() {
   auto valueTy = getType();
 
   if (tdescTy.getRank() != 2)
-    return emitOpError(
-        "The TensorDesc for LoadNdOp should be a 2D TensorDesc.");
+    return emitOpError("Expecting a 2D TensorDesc.\n");
+
+  if (tdescTy.getScattered())
+    return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valueTy)
     return emitOpError("Invalid result, it should be a VectorType.\n");
 
-  auto tdescElemTy = tdescTy.getElementType();
-  auto valueElemTy = valueTy.getElementType();
+  if (!isReadHintOrNone(getL1HintAttr()))
+    return emitOpError("invlid l1_hint: ") << getL1HintAttr();
 
-  if (tdescElemTy != valueElemTy)
-    return emitOpError(
-        "Value should have the same element type as TensorDesc.");
+  if (!isReadHintOrNone(getL2HintAttr()))
+    return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+  if (!isReadHintOrNone(getL3HintAttr()))
+    return emitOpError("invlid l3_hint: ") << getL3HintAttr();
 
   auto array_len = tdescTy.getArrayLength();
   auto tdescShape = tdescTy.getShape().vec();
@@ -174,26 +236,169 @@ LogicalResult LoadNdOp::verify() {
 // XeGPU_StoreNdOp
 //===----------------------------------------------------------------------===//
 LogicalResult StoreNdOp::verify() {
-  auto dstTy = getTensorDesc().getType();               // Tile
-  auto valTy = getValue().getType().cast<VectorType>(); // Vector
+  auto dstTy = getTensorDescType(); // Tile
+  auto valTy = getValueType();      // Vector
 
   if (dstTy.getRank() != 2)
-    return emitOpError("Expecting a 2D TensorDesc shape.\n");
+    return emitOpError("Expecting a 2D TensorDesc.\n");
+  // Only non-scattered descriptors are accepted here.
+  if (dstTy.getScattered())
+    return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valTy)
     return emitOpError("Exepcting a VectorType result.\n");
 
-  auto dstElemTy = dstTy.getElementType();
-  auto valElemTy = valTy.getElementType();
+  if (!isWriteHintOrNone(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isWriteHintOrNone(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isWriteHintOrNone(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  return success();
+}
 
-  if (dstElemTy != valElemTy) {
-    return emitOpError() << "The element type of the value should "
-                            "match the elementtype of the TensorDesc.\n";
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateNDOffsetOp
+//===----------------------------------------------------------------------===//
+LogicalResult UpdateNdOffsetOp::verify() {
+  // number of offsets specified must match the rank of the tensor descriptor
+  if (getTensorDescType().getRank() != (int64_t)getNumOffsets()) {
+    return emitOpError("Invalid number of offsets.");
   }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateDescOp
+//===----------------------------------------------------------------------===//
+// Convenience builder: splits mixed `offsets` into SSA values and constants.
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source,
+                         llvm::ArrayRef<OpFoldResult> offsets,
+                         uint32_t chunk_size) {
+  llvm::SmallVector<Value> dynOffs;
+  llvm::SmallVector<int64_t> cstOffs;
+  dispatchIndexOpFoldResults(offsets, dynOffs, cstOffs);
+  build(builder, state, TensorDesc, source, dynOffs, cstOffs, chunk_size);
+}
+
+// Checks the source rank, the scattered flag, and the descriptor's shape.
+LogicalResult CreateDescOp::verify() {
+  auto tdescTy = getTensorDescType();
+  auto chunkSize = getChunkSize();
+  // getRankOf() is 0 for a non-shaped source, so only rank 0 or 1 may pass.
+  if (getRankOf(getSource()) > 1)
+    return emitOpError(
+        "Expecting the source is a 1D memref or pointer (uint64_t).");
+
+  if (!tdescTy.getScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  // Expected shape: one row per offset, plus a chunk dimension unless it is 1.
+  std::vector<int64_t> shape({(int64_t)getNumOffsets()});
+  if (chunkSize != 1)
+    shape.push_back(chunkSize);
+  if (shape != tdescTy.getShape().vec())
+    return emitOpError("Expecting the size of offsets matches TensorDesc[0].");
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchOp
+//===----------------------------------------------------------------------===//
+// Requires a scattered descriptor whose cache hints are all valid for reads.
+LogicalResult PrefetchOp::verify() {
+  if (!getTensorDescType().getScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!isReadHintOrNone(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isReadHintOrNone(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isReadHintOrNone(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadGatherOp
+//===----------------------------------------------------------------------===//
+// Verifies descriptor kind, cache hints, element type and the result shape.
+LogicalResult LoadGatherOp::verify() {
+  auto tdescTy = getTensorDescType();
+  auto maskTy = getMaskType();
+  auto valueTy = getValueType();
+  if (!tdescTy.getScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!isReadHintOrNone(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isReadHintOrNone(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isReadHintOrNone(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  auto tdescElemTy = tdescTy.getElementType();
+  auto valueElemTy = getElementType();
+  if (tdescElemTy != valueElemTy)
+    return emitOpError(
+        "Value should have the same element type as TensorDesc.");
+
+  std::vector<int64_t> maskShape = getShapeOf(maskTy);
+  std::vector<int64_t> valueShape = getShapeOf(valueTy);
+  std::vector<int64_t> tdescShape = getShapeOf(tdescTy);
+  if (tdescShape[0] != maskShape[0])
+    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
+
+  // Apply the transpose to the expected shape when it fits; otherwise warn.
+  if (getTransposeAttr()) {
+    auto trans = getTranspose().value();
+    if (tdescShape.size() < trans.size())
+      emitWarning("Invalid transpose attr. It is ignored.");
+    else
+      transpose(trans, tdescShape);
+  }
+
+  if (valueShape != tdescShape)
+    return emitOpError("Unexpected result shape")
+           << "(Expected shape: " << makeString(tdescShape)
+           << ", Given shape: " << makeString(valueShape) << ").\n";
+
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreScatterOp
+//===----------------------------------------------------------------------===//
+// Requires a scattered descriptor, valid write hints and a matching mask dim-0.
+LogicalResult StoreScatterOp::verify() {
+  auto tdescTy = getTensorDescType();
+  if (!tdescTy.getScattered())
+    return emitOpError("Expects a scattered TensorDesc.\n");
+
+  if (!isWriteHintOrNone(getL1HintAttr()))
+    return emitOpError("invalid l1_hint: ") << getL1HintAttr();
+
+  if (!isWriteHintOrNone(getL2HintAttr()))
+    return emitOpError("invalid l2_hint: ") << getL2HintAttr();
+
+  if (!isWriteHintOrNone(getL3HintAttr()))
+    return emitOpError("invalid l3_hint: ") << getL3HintAttr();
+
+  std::vector<int64_t> maskShape = getShapeOf(getMaskType());
+  std::vector<int64_t> tdescShape = getShapeOf(tdescTy);
+  if (tdescShape[0] != maskShape[0])
+    return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
-  if (dstTy.getShape() != valTy.getShape())
-    return emitOpError()
-           << "The result shape should match the TensorDesc shape.\n";
   return success();
 }
 
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 039346adbb851c..f0945c79a94ac3 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -59,4 +59,66 @@ gpu.func @test_store_nd_vc(%dst: memref<24x32xf16>) {
   gpu.return
 }
 
+// CHECK: gpu.func @test_create_update_nd_tdesc_vc(%[[arg0:.*]]: memref<24x32xf32>) {
+gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+  // CHECK: %[[R1:.*]] = xegpu.update_nd_offset %[[REG]], [0, 16] : !xegpu.tensor_desc<8x16xf32>
+  %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_tdesc_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_tdesc_vc(%src: ui64) {
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_prefetch_vc(%src: ui64) {
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>> 
+  gpu.return
+}
+
+// CHECK: gpu.func @test_load_gather_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_load_gather_vc(%src: ui64) {
+  //CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<4xi1>
+  %0 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+  //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1> -> vector<4x2xf32>
+  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1> -> vector<4x2xf32>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_store_scatter_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_store_scatter_vc(%src: ui64) {
+  //CHECK: %[[c0:.*]] = arith.constant dense<true> : vector<4xi1>
+  %0 = arith.constant dense<1>: vector<4xi1>
+  //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<4x2xf32>
+  %1 = arith.constant dense<2.9>: vector<4x2xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %2 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+  //CHECK-SAME: vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
+  xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
+  gpu.return
+}
+
+// CHECK: gpu.func @test_create_update_tdesc_vc(%[[arg0:.*]]: ui64) {
+gpu.func @test_create_update_tdesc_vc(%src: ui64) {
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %2 = xegpu.update_offset %1, [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  gpu.return
+}
+
 }
\ No newline at end of file

>From 89148e9f58d02795550f735bf350b63036cb442c Mon Sep 17 00:00:00 2001
From: Chao Chen <116223022+chencha3 at users.noreply.github.com>
Date: Tue, 26 Mar 2024 09:50:09 -0500
Subject: [PATCH 2/2] Update mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Co-authored-by: Mehdi Amini <joker.eph at gmail.com>
---
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 4efa46642aa78f..dc18d8c9b40366 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -41,7 +41,7 @@ static std::string makeString(T array, bool breakline = false) {
   return buf;
 }
 
-static std::vector<int64_t> getShapeOf(Type type) {
+static SmallVector<int64_t> getShapeOf(Type type) {
   std::vector<int64_t> shape;
   if (auto ty = llvm::dyn_cast<ShapedType>(type))
     shape = ty.getShape().vec();



More information about the Mlir-commits mailing list