[Mlir-commits] [mlir] update tdesc_attr (PR #109144)

Chao Chen llvmlistbot at llvm.org
Wed Sep 18 13:28:58 PDT 2024


https://github.com/chencha3 updated https://github.com/llvm/llvm-project/pull/109144

>From 2cd064a1f542de38cd42eb29e2d4bf5650282763 Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Tue, 17 Sep 2024 18:58:18 +0000
Subject: [PATCH 1/3] update tdesc_attr

---
 .../mlir/Dialect/XeGPU/IR/XeGPUAttrs.td       | 46 +++++++++++---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 19 ++----
 .../mlir/Dialect/XeGPU/IR/XeGPUTypes.td       | 63 ++++++++++++-------
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    | 40 ++++++++----
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 31 ++++-----
 5 files changed, 122 insertions(+), 77 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f3ca09a6a68ea8..6ffb4eb3c60f2b 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -19,9 +19,15 @@ class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
   let mnemonic = attrMnemonic;
 }
 
-def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
+class XeGPU_TensorDescAttr<string name, string attrMnemonic, list<Trait> traits = [],
+                         string baseCppClass = "::mlir::Attribute">
+    : XeGPUAttr<name, attrMnemonic, traits, baseCppClass> {
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_tdesc_attr"> {
   let summary = [{a composite attribute for `TensorDescType`}];
-  let description = [{`TensorDescAttr` (or `tdesc_attr`) is a composite
+  let description = [{`BlockTensorDesc` (or `block_tdesc_attr`) is a composite
     attribute defined for `TensorDescType` for describing following
     properties of a `TensorDesc`.
     1. `memory_scope`: It describes where the data block described by the
@@ -33,29 +39,49 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
         8x32. Its default value is 1.
     3. `boundary_check`: It is used to indicates the hardware whether to do
         out-of-boundary check. The default value is true.
-    4. `scattered`: It is used to differenciate TensorDescs created from
-       `create_nd_tdesc` vs from `create_tdesc`.
   }];
 
   let parameters = (ins
     OptionalParameter<"MemoryScopeAttr">: $memory_scope,
     OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check,
-    OptionalParameter<"BoolAttr", "false">: $scattered
+    OptionalParameter<"BoolAttr", "true">: $boundary_check
   );
 
   let builders = [
     AttrBuilder<(ins
       CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
       CArg<"int", "1">:$array_length,
-      CArg<"bool", "true">: $boundary_check,
-      CArg<"bool", "false">: $scattered
+      CArg<"bool", "true">: $boundary_check
     )>
   ];
 
-  let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
+  let summary = [{a composite attribute for `TensorDescType`}];
+  let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite
+    attribute defined for `TensorDescType` for describing the following
+    properties of a `TensorDesc`.
+    1. `memory_scope`: It describes where the data block described by the
+        TensorDesc is located, `Global` device memory or `Shared` local memory.
+        It defaults to `Global`.
+    2.  `chunk_size`: indicates the number of contiguous elements accessed for
+        each offset; the default is 1. It applies to scattered TensorDescs only.
+  }];
+
+  let parameters = (ins
+    OptionalParameter<"MemoryScopeAttr">: $memory_scope,
+    OptionalParameter<"IntegerAttr", "1">: $chunk_size
+  );
+
+  let builders = [
+    AttrBuilder<(ins
+      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
+      CArg<"int", "1">: $chunk_size
+    )>
+  ];
+ }
+
 //===----------------------------------------------------------------------===//
 // XeGPU Memory Scope Enums.
 //===----------------------------------------------------------------------===//
@@ -116,4 +142,4 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
\ No newline at end of file
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c32c7541c39791..13a0bff5de1a6e 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -411,42 +411,33 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
       is fixed to the hardware supportted subgroup size, e.g., 16 on PVC,
       implying each element in the array corresponds to a work-item (SIMT lane)
       in the subgroup.
-    * chunk_size: [optional attribute] indicates number of continious
-      elements accessed for each offset, default is 1.
 
     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
+    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32, chunk_size_per_lane = 1>
     ```
 
     Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
                It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>
     ```
 
     Example 3. It is similar to Example 2, but there is some overlaps among workitems.
                It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>>
     ```
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
                        Variadic<Index>: $offsets,
-                       DenseI64ArrayAttr: $const_offsets,
-                       DefaultValuedAttr<I64Attr, "1">: $chunk_size);
+                       DenseI64ArrayAttr: $const_offsets);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);
 
-  let builders = [
-    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   CArg<"uint32_t", "1"> : $chunk_size)>,
-  ];
-
   let assemblyFormat = [{
     $source
     custom<DynamicIndexList>($offsets, $const_offsets)
@@ -723,7 +714,7 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
 
 def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
       AllElementTypesMatch<["tensorDesc", "value", "result"]>,
-      AllShapesMatch<["tensorDesc", "mask", "value", "result"]>]> {
+      AllShapesMatch<["tensorDesc", "value", "result"]>]> {
   let summary = "Atomic ready-modify-write operation on the TensorDesc. ";
 
   let description = [{
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 9f101a71697b56..8b22baf365afa2 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -88,11 +88,14 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>": $shape,
       "mlir::Type": $elementType,
-      CArg<"bool", "false">: $scattered,
       CArg<"int", "1">: $array_length,
-      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
-      CArg<"bool", "true">: $boundary_check
-    )>
+      CArg<"bool", "true">: $boundary_check,
+      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope)>,
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>": $shape,
+      "mlir::Type": $elementType,
+      CArg<"int", "1">: $chunk_size,
+      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope)>
   ];
 
   let extraClassDeclaration = [{
@@ -110,40 +113,58 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
     }
 
-    TensorDescAttr getEncodingAsTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<TensorDescAttr>(getEncoding());
+    BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
+      return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
+    }
+
+    ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
+      return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
     }
 
     xegpu::MemoryScope getMemoryScope() const {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getMemoryScope())
-        return attr.getMemoryScope().getValue();
+      auto block_attr = getEncodingAsBlockTensorDescAttr();
+      if (block_attr && block_attr.getMemoryScope())
+        return block_attr.getMemoryScope().getValue();
+
+      auto scatter_attr = getEncodingAsScatterTensorDescAttr();
+      if (scatter_attr && scatter_attr.getMemoryScope())
+        return scatter_attr.getMemoryScope().getValue();
+
       // return default value
       return MemoryScope::Global;
     }
 
     int getArrayLength() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getArrayLength())
-        return attr.getArrayLength().getInt();
+      auto attr = getEncoding();
+      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
+      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
+      if (block_attr && block_attr.getArrayLength())
+        return block_attr.getArrayLength().getInt();
       // return default value
       return 1;
     }
 
     bool getBoundaryCheck() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getBoundaryCheck())
-        return attr.getBoundaryCheck().getValue();
+      auto attr = getEncoding();
+      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
+      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
+      if (block_attr && block_attr.getBoundaryCheck())
+        return block_attr.getBoundaryCheck().getValue();
       // return default value
       return true;
     }
 
-    bool getScattered() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getScattered())
-        return attr.getScattered().getValue();
-      // return default value
-      return false;
+    bool isScattered() {
+      return bool(getEncodingAsScatterTensorDescAttr());
+    }
+
+    int getChunkSize() {
+      auto attr = getEncoding();
+      auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
+      assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
+      if (scatter_attr && scatter_attr.getChunkSize())
+        return scatter_attr.getChunkSize().getInt();
+      return 1;
     }
   }];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24719fe748fe4f..0eab601bbaac4c 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -30,18 +30,28 @@ void XeGPUDialect::initialize() {
 }
 
 //===----------------------------------------------------------------------===//
-// XeGPU_TensorDescAttr
+// XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
-TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
-                                   xegpu::MemoryScope memory_scope,
-                                   int array_length, bool boundary_check,
-                                   bool scattered) {
+BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
+                                        xegpu::MemoryScope memory_scope,
+                                        int array_length, bool boundary_check) {
   auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
   auto lengthAttr =
       IntegerAttr::get(IntegerType::get(context, 64), array_length);
   auto boundaryAttr = BoolAttr::get(context, boundary_check);
-  auto scatteredAttr = BoolAttr::get(context, scattered);
-  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr);
+  return Base::get(context, scopeAttr, lengthAttr, boundaryAttr);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_ScatterTensorDescAttr
+//===----------------------------------------------------------------------===//
+ScatterTensorDescAttr ScatterTensorDescAttr::get(mlir::MLIRContext *context,
+                                            xegpu::MemoryScope memory_scope,
+                                            int chunk_size) {
+  auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
+  auto chunkSizeAttr =
+      IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
+  return Base::get(context, scopeAttr, chunkSizeAttr);
 }
 
 //===----------------------------------------------------------------------===//
@@ -108,12 +118,18 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
 }
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
-                                   mlir::Type elementType, bool scattered,
-                                   int array_length, MemoryScope memory_scope,
-                                   bool boundary_check) {
+                                   mlir::Type elementType, int array_length,
+                                   bool boundary_check, MemoryScope memory_scope) {
+  auto context = elementType.getContext();
+  auto attr = BlockTensorDescAttr::get(context, memory_scope, array_length, boundary_check);
+  return Base::get(context, shape, elementType, attr);
+}
+
+TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
+                                   mlir::Type elementType, int chunk_size,
+                                   MemoryScope memory_scope) {
   auto context = elementType.getContext();
-  auto attr = TensorDescAttr::get(context, memory_scope, array_length,
-                                  boundary_check, scattered);
+  auto attr = ScatterTensorDescAttr::get(context, memory_scope, chunk_size);
   return Base::get(context, shape, elementType, attr);
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 8e185b8d2586d9..ee3834bd0d9cc6 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -153,7 +153,7 @@ LogicalResult CreateNdDescOp::verify() {
     return emitOpError("TensorDesc should have the same element "
                        "type with the source if it is a memref.\n");
 
-  if (getType().getScattered())
+  if (getType().isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   return success();
@@ -164,7 +164,7 @@ LogicalResult CreateNdDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchNdOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (tdescTy.getScattered())
+  if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -189,7 +189,7 @@ LogicalResult LoadNdOp::verify() {
   if (tdescTy.getRank() > 2)
     return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
-  if (tdescTy.getScattered())
+  if (tdescTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valueTy)
@@ -257,7 +257,7 @@ LogicalResult StoreNdOp::verify() {
   if (dstTy.getRank() > 2)
     return emitOpError("Expecting a 1D/2D TensorDesc.\n");
 
-  if (dstTy.getScattered())
+  if (dstTy.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   if (!valTy)
@@ -280,7 +280,7 @@ LogicalResult StoreNdOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult UpdateNdOffsetOp::verify() {
   auto ty = getTensorDescType();
-  if (ty.getScattered())
+  if (ty.isScattered())
     return emitOpError("Expects a non-scattered TensorDesc.\n");
 
   // number of offsets specified must match the rank of the tensor descriptor
@@ -293,28 +293,19 @@ LogicalResult UpdateNdOffsetOp::verify() {
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateDescOp
 //===----------------------------------------------------------------------===//
-void CreateDescOp::build(OpBuilder &builder, OperationState &state,
-                         TensorDescType TensorDesc, Value source,
-                         llvm::ArrayRef<OpFoldResult> offsets,
-                         uint32_t chunk_size) {
-  llvm::SmallVector<int64_t> staticOffsets;
-  llvm::SmallVector<Value> dynamicOffsets;
-  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
-  build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
-        chunk_size);
-}
 
 LogicalResult CreateDescOp::verify() {
   auto tdescTy = getTensorDescType();
-  auto chunkSize = getChunkSize();
 
   if (getRankOf(getSource()) > 1)
     return emitOpError(
         "Expecting the source is a 1D memref or pointer (uint64_t).");
 
-  if (!tdescTy.getScattered())
+  if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
+  auto chunkSize = tdescTy.getChunkSize();
+
   SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
   if (chunkSize != 1)
     shape.push_back(chunkSize);
@@ -332,7 +323,7 @@ LogicalResult CreateDescOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult PrefetchOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.getScattered())
+  if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -355,7 +346,7 @@ LogicalResult LoadGatherOp::verify() {
   auto maskTy = getMaskType();
   auto valueTy = getValueType();
 
-  if (!tdescTy.getScattered())
+  if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isReadHintOrNone(getL1HintAttr()))
@@ -401,7 +392,7 @@ LogicalResult LoadGatherOp::verify() {
 //===----------------------------------------------------------------------===//
 LogicalResult StoreScatterOp::verify() {
   auto tdescTy = getTensorDescType();
-  if (!tdescTy.getScattered())
+  if (!tdescTy.isScattered())
     return emitOpError("Expects a scattered TensorDesc.\n");
 
   if (!isWriteHintOrNone(getL1HintAttr()))

>From 24adc84d0a42f5e7712291ef3a886fa5de044f0f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 18 Sep 2024 13:55:40 +0000
Subject: [PATCH 2/3] update load_gather and store_scatter

---
 .../include/mlir/Dialect/XeGPU/IR/XeGPUOps.td | 40 ++++++++-----
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp    |  1 +
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp        | 48 ++++++++++++---
 mlir/test/Dialect/XeGPU/XeGPUOps.mlir         | 60 +++++++++----------
 mlir/test/Dialect/XeGPU/invalid.mlir          | 54 ++++++++---------
 5 files changed, 124 insertions(+), 79 deletions(-)

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index 13a0bff5de1a6e..1d379460a48234 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -412,24 +412,28 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
       implying each element in the array corresponds to a work-item (SIMT lane)
       in the subgroup.
 
+    The first dimension of the result TensorDesc corresponds to work-items, so it should
+    match the dimension of offsets. It may also have a second dimension corresponding to
+    the chunk_size if the chunk size is larger than 1.
+
     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32, chunk_size_per_lane = 1>
+    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
     ```
 
     Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
                It will access totally 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>
+    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
     ```
 
     Example 3. It is similar to Example 2, but there is some overlaps among workitems.
                It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size_per_lane = 8>>
+    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
     ```
   }];
 
@@ -511,28 +515,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
 
   let description = [{ It (aka. load) load data per each work-item. The output
     describes the data being loaded at the subgroup level, so its size is
-    consistent with the number of work-items in a subgroup. When `chunk_size_per_lane`
-    attribute is larger than 1 in TensorDesc, the output vector will be 2D vector,
-    with dim-1 correspoding to the chunk size.
+    consistent with the number of work-items in a subgroup. When the chunk size
+    is larger than 1, the output vector is a 2D vector, with dim-1 corresponding
+    to work-items, and dim-0 corresponding to the chunk_size loaded by each work-item.
+    Notably, there is a transpose effect on the result (as compared to the TensorDesc)
+    due to the hardware implementation. Therefore, a transpose attribute is introduced
+    on purpose, making sure users are aware of this implicit transformation.
 
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
   Example:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose = [1, 0],
+    %2 = xegpu.load %1, %0 {transpose,
                             l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
-          : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
-            -> vector<16xf32>
+          : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_scope=global>>,
+            vector<16xi1> -> vector<16xf32>
   ```
 
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_MaskType: $mask,
-                       OptionalAttr<DenseI64ArrayAttr>: $transpose,
+                       OptionalAttr<UnitAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -564,11 +571,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDesc"]>,
-                                        AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
+                                              AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "store data to scattered memory locations.";
-  let description = [{ It (aka. store) stores data to scattered memory locations.
-  It has similar semantic to `load_gather`.
+  let description = [{ It (aka. store) stores data to scattered memory locations. The value is
+  typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
+  a 2D vector instead. For the latter case, dim-1 of the value corresponds to the simd lanes
+  and the dim-0 of the value corresponds to the chunk_size stored per lane. So `store_scatter`
+  has a transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
+  introduced on purpose, making sure users are aware of this implicit transformation.
 
   Example:
   ```mlir
@@ -583,6 +594,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc: $TensorDesc,
     XeGPU_MaskType: $mask,
+    OptionalAttr<UnitAttr>: $transpose,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0eab601bbaac4c..555c232ff1f063 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -57,6 +57,7 @@ ScatterTensorDescAttr ScatterTensorDescAttr::get(mlir::MLIRContext *context,
 //===----------------------------------------------------------------------===//
 // XeGPU_TensorDescType
 //===----------------------------------------------------------------------===//
+
 mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
   llvm::SmallVector<int64_t> shape;
   mlir::Type elementType;
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index ee3834bd0d9cc6..0da38df90fdbd4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -229,8 +229,8 @@ LogicalResult LoadNdOp::verify() {
       tdescShape[axis] /= vnni_factor;
       tdescShape.push_back(vnni_factor);
     } else {
-      return emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
-                         "TensorDesc only).");
+      emitWarning("Invalid Packed Attr. It is ignored (available for 2D "
+                  "TensorDesc only).");
     }
   }
 
@@ -306,6 +306,26 @@ LogicalResult CreateDescOp::verify() {
 
   auto chunkSize = tdescTy.getChunkSize();
 
+  // check chunk_size
+  llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, 16, 32, 64, 128, 256};
+  if (!llvm::is_contained(supportedChunkSizes, chunkSize))
+    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, 8, 16, 32, 64, 128, or 256.");
+
+  // check total size
+  auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
+  auto bitsPerLane = elemBits * chunkSize;
+  if (chunkSize > 1 && bitsPerLane % 32) {
+    // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
+    // For 32-bit data, the hardware can support larger chunk sizes. So
+    // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
+    // But this requires the total size to be 32-bit aligned to make the optimization work.
+    return emitOpError("access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
+  }
+
+  auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
+  if (elemBits * tdescTy.getNumElements() > lscConstraints)
+    return emitOpError("total access size (simd_lanes * chunk_size * sizeof(elemTy)) is upto 512 bytes.");
+
   SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
   if (chunkSize != 1)
     shape.push_back(chunkSize);
@@ -371,14 +391,13 @@ LogicalResult LoadGatherOp::verify() {
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
-  if (getTransposeAttr()) {
-    auto trans = getTranspose().value();
-    if (tdescShape.size() < trans.size())
-      emitWarning("Invalid transpose attr. It is ignored.");
-    else
-      transpose(trans, tdescShape);
+  if (tdescTy.getRank() == 2) {
+    if (!getTransposeAttr())
+      return emitOpError("load_gather has to be transposed.");
+    transpose({1, 0}, tdescShape);
   }
 
+
   if (valueShape != tdescShape)
     return emitOpError("Unexpected result shape")
            << "(Expected shape: " << makeString(tdescShape)
@@ -405,11 +424,24 @@ LogicalResult StoreScatterOp::verify() {
     return emitOpError("invlid l3_hint: ") << getL3HintAttr();
 
   auto maskTy = getMaskType();
+  auto valueTy = getValueType();
   auto maskShape = getShapeOf(maskTy);
   auto tdescShape = getShapeOf(tdescTy);
+  auto valueShape = getShapeOf(valueTy);
   if (tdescShape[0] != maskShape[0])
     return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
 
+  if (tdescTy.getRank() == 2) {
+    if (!getTransposeAttr())
+      return emitOpError("load_gather has to be transposed.");
+    transpose({1, 0}, tdescShape);
+  }
+
+  if (valueShape != tdescShape)
+    return emitOpError("Unexpected value shape")
+           << "(Expected shape: " << makeString(tdescShape)
+           << ", Given shape: " << makeString(valueShape) << ").\n";
+
   return success();
 }
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index 35d44cf56a239b..a815f2b14b2000 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -24,8 +24,8 @@ gpu.func @test_create_nd_tdesc_vc_2(%src: ui64, %w : index, %h : index, %x : ind
 
 // CHECK: gpu.func @test_create_nd_tdesc_vc_3(%[[arg0:.*]]: memref<24x32xf32>) {
 gpu.func @test_create_nd_tdesc_vc_3(%src: memref<24x32xf32>) {
-  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.tdesc_attr<array_length = 2 : i64>
-  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.tdesc_attr<array_length = 2>>
+  // CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2 : i64>
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x16xf32, #xegpu.block_tdesc_attr<array_length = 2>>
   gpu.return
 }
 
@@ -97,17 +97,17 @@ gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
 
 // CHECK: gpu.func @test_create_tdesc_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_tdesc_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   gpu.return
 }
 
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   gpu.return
 }
 
@@ -115,12 +115,12 @@ gpu.func @test_prefetch_vc(%src: ui64) {
 gpu.func @test_load_gather_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<4xi1>
   %0 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-  //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1> -> vector<4x2xf32>
-  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1> -> vector<4x2xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
+  //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
+  %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
+        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1> -> vector<2x4xf32>
   gpu.return
 }
 
@@ -128,23 +128,23 @@ gpu.func @test_load_gather_vc(%src: ui64) {
 gpu.func @test_store_scatter_vc(%src: ui64) {
   //CHECK: %[[c0:.*]] = arith.constant dense<true> : vector<4xi1>
   %0 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<4x2xf32>
-  %1 = arith.constant dense<2.9>: vector<4x2xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %2 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-  //CHECK-SAME: vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
-  xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>}>
-        : vector<4x2xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
+  //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
+  %1 = arith.constant dense<2.9>: vector<2x4xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %2 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
+  //CHECK-SAME: vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
+  xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
+        : vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   gpu.return
 }
 
 // CHECK: gpu.func @test_create_update_tdesc_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_update_tdesc_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] {chunk_size = 2 : i64} : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
-  %2 = xegpu.update_offset %1, [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24]: ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %2 = xegpu.update_offset %1, [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   gpu.return
 }
 
@@ -165,10 +165,10 @@ gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf1
 
 // CHECK: gpu.func @test_atomic_rmw(%[[arg0:.*]]: ui64, %[[arg1:.*]]: vector<16xf32>, %[[arg2:.*]]: vector<16xi1>)
 gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
-  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>
-  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
-  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
+  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   gpu.return
 }
 
diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 7ef50bb2b5fadf..d2d1ad5273e9cd 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -26,10 +26,10 @@ func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
 // -----
 func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
   %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7]
-        : memref<24xf16> -> !xegpu.tensor_desc<8xf16, #xegpu.tdesc_attr<scattered=true>>
+        : memref<24xf16> -> !xegpu.tensor_desc<8xf16, #xegpu.scatter_tdesc_attr<>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>}>
-        : !xegpu.tensor_desc<8xf16, #xegpu.tdesc_attr<scattered=true>>
+        : !xegpu.tensor_desc<8xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
 
@@ -44,11 +44,11 @@ func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
 
 // -----
 func.func @test_load_nd_vc_2(%src: memref<16xf16>) {
-  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>>
+  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14]
+        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc.}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>}>
-      : !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>> -> vector<8x2xf16>
+      : !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>> -> vector<8x2xf16>
   return
 }
 
@@ -73,28 +73,28 @@ func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
 // -----
 func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
   %1 = arith.constant dense<1.0>: vector<8x2xf16>
-  %2 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>>
+  %2 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14]
+        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>
-        : vector<8x2xf16>, !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>>
+        : vector<8x2xf16>, !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
 
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
-  %1 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>>
+  %1 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14]
+        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
-  xegpu.update_nd_offset %1, [0, 2] : !xegpu.tensor_desc<8x2xf16, #xegpu.tdesc_attr<scattered=true>>
+  xegpu.update_nd_offset %1, [0, 2] : !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
 
 // -----
 func.func @test_create_tdesc_vc_1(%src: ui64) {
   // expected-error at +1 {{Expects a scattered TensorDesc}}
-  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : ui64 -> !xegpu.tensor_desc<8x2xf16>
+  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14]
+        : ui64 -> !xegpu.tensor_desc<8xf16>
   return
 }
 
@@ -102,7 +102,7 @@ func.func @test_create_tdesc_vc_1(%src: ui64) {
 func.func @test_create_tdesc_vc_2(%src: ui64) {
   // expected-error at +1 {{Incorrect TensorDesc shape}}
   %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : ui64 -> !xegpu.tensor_desc<8x4xf16, #xegpu.tdesc_attr<scattered = true>>
+        : ui64 -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
 
@@ -116,9 +116,9 @@ func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) {
 
 // -----
 func.func @test_prefetch_vc_2(%src: ui64) {
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
-  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
 }
 
@@ -135,11 +135,11 @@ func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) {
 // -----
 func.func @test_load_gather_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2} : ui64
-        -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64
+        -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
-        : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
+        : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
           -> vector<4x2xf32>
   return
 }
@@ -159,11 +159,11 @@ func.func @test_store_scatter_vc_1(%src: memref<24x32xf32>) {
 func.func @test_store_scatter_vc_2(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
-  %2 = xegpu.create_tdesc %src[0, 8, 16, 24] {chunk_size = 2}
-          : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+  %2 = xegpu.create_tdesc %src[0, 8, 16, 24]
+          : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
-          !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>, vector<4xi1>
+          !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
   return
 }
 
@@ -182,9 +182,9 @@ func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
 }
 
 // -----
-func.func @test_atomic_rmw(%src: ui64, %value : vector<16x8xf32>, %mask : vector<16xi1>) {
-  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] {chunk_size = 8}: ui64 -> !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>
-  // expected-error at +1 {{failed to verify that all of {tensorDesc, mask, value, result} have same shape}}
-  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.tdesc_attr<scattered = true>>, vector<16xi1>, vector<16x8xf32> -> vector<16x8xf32>
-  gpu.return
+func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
+  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  // expected-error at +1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
+  xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
+  return
 }
\ No newline at end of file

>From 1a50b1327a76d5f366c0e587e81d89ff49e1406f Mon Sep 17 00:00:00 2001
From: Chao Chen <chao.chen at intel.com>
Date: Wed, 18 Sep 2024 20:24:32 +0000
Subject: [PATCH 3/3] format the code

---
 mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 17 ++++++++++-------
 mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp     | 16 ++++++++++------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 555c232ff1f063..4573045515601f 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -33,8 +33,9 @@ void XeGPUDialect::initialize() {
 // XeGPU_BlockTensorDescAttr
 //===----------------------------------------------------------------------===//
 BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
-                                        xegpu::MemoryScope memory_scope,
-                                        int array_length, bool boundary_check) {
+                                             xegpu::MemoryScope memory_scope,
+                                             int array_length,
+                                             bool boundary_check) {
   auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
   auto lengthAttr =
       IntegerAttr::get(IntegerType::get(context, 64), array_length);
@@ -45,9 +46,9 @@ BlockTensorDescAttr BlockTensorDescAttr::get(mlir::MLIRContext *context,
 //===----------------------------------------------------------------------===//
 // XeGPU_ScatterTensorDescAttr
 //===----------------------------------------------------------------------===//
-ScatterTensorDescAttr ScatterTensorDescAttr::get(mlir::MLIRContext *context,
-                                            xegpu::MemoryScope memory_scope,
-                                            int chunk_size) {
+ScatterTensorDescAttr
+ScatterTensorDescAttr::get(mlir::MLIRContext *context,
+                           xegpu::MemoryScope memory_scope, int chunk_size) {
   auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
   auto chunkSizeAttr =
       IntegerAttr::get(IntegerType::get(context, 64), chunk_size);
@@ -120,9 +121,11 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
 
 TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
                                    mlir::Type elementType, int array_length,
-                                   bool boundary_check, MemoryScope memory_scope) {
+                                   bool boundary_check,
+                                   MemoryScope memory_scope) {
   auto context = elementType.getContext();
-  auto attr = BlockTensorDescAttr::get(context, memory_scope, array_length, boundary_check);
+  auto attr = BlockTensorDescAttr::get(context, memory_scope, array_length,
+                                       boundary_check);
   return Base::get(context, shape, elementType, attr);
 }
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 0da38df90fdbd4..a4e9bbe58c83d4 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -307,9 +307,11 @@ LogicalResult CreateDescOp::verify() {
   auto chunkSize = tdescTy.getChunkSize();
 
   // check chunk_size
-  llvm::SmallVector<int64_t> supportedChunkSizes = {1, 2, 3, 4, 8, 16, 32, 64, 128, 256};
+  llvm::SmallVector<int64_t> supportedChunkSizes = {1,  2,  3,  4,   8,
+                                                    16, 32, 64, 128, 256};
   if (!llvm::is_contained(supportedChunkSizes, chunkSize))
-    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, 8, 16, 32, 64, 128, or 256.");
+    return emitOpError("Invalid chunk_size. Supported values are 1, 2, 3, 4, "
+                       "8, 16, 32, 64, 128, or 256.");
 
   // check total size
   auto elemBits = tdescTy.getElementType().getIntOrFloatBitWidth();
@@ -318,13 +320,16 @@ LogicalResult CreateDescOp::verify() {
     // For 8-bit and 16-bit data, the hardware only supports chunk size of 1.
     // For 32-bit data, the hardware can support larger larger chunk size. So
     // we can bitcast 8-bit/16-bit data to 32-bit data for better performance.
-    // But this requires the total size is 32 bit aligned to make the optimization work.
-    return emitOpError("access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
+    // But this requires the total size is 32 bit aligned to make the
+    // optimization work.
+    return emitOpError(
+        "access size (chunk_size * sizeof(elemTy)) should be 32-bit aligned.");
   }
 
   auto lscConstraints = 512 * 8; // each access is upto 512 bytes.
   if (elemBits * tdescTy.getNumElements() > lscConstraints)
-    return emitOpError("total access size (simd_lanes * chunk_size * sizeof(elemTy)) is upto 512 bytes.");
+    return emitOpError("total access size (simd_lanes * chunk_size * "
+                       "sizeof(elemTy)) is upto 512 bytes.");
 
   SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
   if (chunkSize != 1)
@@ -397,7 +402,6 @@ LogicalResult LoadGatherOp::verify() {
     transpose({1, 0}, tdescShape);
   }
 
-
   if (valueShape != tdescShape)
     return emitOpError("Unexpected result shape")
            << "(Expected shape: " << makeString(tdescShape)



More information about the Mlir-commits mailing list