[Mlir-commits] [mlir] [MLIR][XeGPU] Updates XeGPU TensorDescAttr and Refines Gather/Scatter Definitions. (PR #109144)

llvmlistbot at llvm.org
Thu Sep 19 10:50:24 PDT 2024


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Chao Chen (chencha3)

Changes:

This PR makes the following refinements to the XeGPU dialect (sketched in the examples below).
1. Separated the old `TensorDescAttr` into two independent attributes, `BlockTensorDescAttr` and `ScatterTensorDescAttr`.
2. Renamed `MemoryScopeAttr` to `MemorySpaceAttr`, and updated the enumeration value for shared memory to follow the OpenCL standard.
3. Introduced a `transpose` UnitAttr on `StoreScatterOp` and `LoadGatherOp`.
4. Added a memory-space check for the `CreateNdDesc` and `CreateDesc` ops, and added valid and invalid test cases for them (see the second sketch, after the diff).
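
A minimal sketch of the updated syntax, assembled from the examples in the diff below; shapes, offsets, and cache hints are illustrative, and `%tdesc`/`%mask` are hypothetical values assumed to be defined elsewhere:

```mlir
// (1) Block descriptors now use #xegpu.block_tdesc_attr (formerly #xegpu.tdesc_attr).
!xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>

// (1)+(2) Scatter descriptors carry chunk_size in #xegpu.scatter_tdesc_attr,
// rather than as an attribute of create_tdesc. With 4 offsets and chunk_size = 8,
// the descriptor is 4x8: dim-0 indexes work-items, dim-1 the elements per offset.
%0 = memref.alloc() : memref<1024xf32>
%1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32>
     -> !xegpu.tensor_desc<4x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>

// (3) Gather load with the new `transpose` unit attribute.
%2 = xegpu.load %tdesc, %mask {transpose, l1_hint = #xegpu.cache_hint<cached>}
     : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space = global>>,
       vector<16xi1> -> vector<16xf32>
```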

---

Patch is 52.14 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/109144.diff


7 Files Affected:

- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td (+45-19) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td (+61-25) 
- (modified) mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td (+47-26) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp (+33-13) 
- (modified) mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp (+81-28) 
- (modified) mlir/test/Dialect/XeGPU/XeGPUOps.mlir (+44-30) 
- (modified) mlir/test/Dialect/XeGPU/invalid.mlir (+48-27) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
index f3ca09a6a68ea8..26eec0d4f2082a 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -19,12 +19,18 @@ class XeGPUAttr<string name, string attrMnemonic, list<Trait> traits = [],
   let mnemonic = attrMnemonic;
 }
 
-def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
+class XeGPU_TensorDescAttr<string name, string attrMnemonic, list<Trait> traits = [],
+                         string baseCppClass = "::mlir::Attribute">
+    : XeGPUAttr<name, attrMnemonic, traits, baseCppClass> {
+  let assemblyFormat = "`<` struct(params) `>`";
+}
+
+def XeGPU_BlockTensorDescAttr: XeGPU_TensorDescAttr<"BlockTensorDesc", "block_tdesc_attr"> {
   let summary = [{a composite attribute for `TensorDescType`}];
-  let description = [{`TensorDescAttr` (or `tdesc_attr`) is a composite
+  let description = [{`BlockTensorDesc` (or `block_tdesc_attr`) is a composite
    attribute defined for `TensorDescType` for describing the following
    properties of a `TensorDesc`.
-    1. `memory_scope`: It describes where the data block described by the
+    1. `memory_space`: It describes where the data block described by the
        TensorDesc is located, `Global` device memory or `Shared` local memory.
        It defaults to `Global`.
     2. `array_length`: It describes how many horizontally consecutive blocks
@@ -33,43 +39,63 @@ def XeGPU_TensorDescAttr: XeGPUAttr<"TensorDesc", "tdesc_attr"> {
         8x32. Its default value is 1.
    3. `boundary_check`: It is used to indicate to the hardware whether to do
        out-of-boundary checks. The default value is true.
-    4. `scattered`: It is used to differenciate TensorDescs created from
-       `create_nd_tdesc` vs from `create_tdesc`.
   }];
 
   let parameters = (ins
-    OptionalParameter<"MemoryScopeAttr">: $memory_scope,
+    OptionalParameter<"MemorySpaceAttr">: $memory_space,
     OptionalParameter<"IntegerAttr", "1">: $array_length,
-    OptionalParameter<"BoolAttr", "true">: $boundary_check,
-    OptionalParameter<"BoolAttr", "false">: $scattered
+    OptionalParameter<"BoolAttr", "true">: $boundary_check
   );
 
   let builders = [
     AttrBuilder<(ins
-      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
       CArg<"int", "1">:$array_length,
-      CArg<"bool", "true">: $boundary_check,
-      CArg<"bool", "false">: $scattered
+      CArg<"bool", "true">: $boundary_check
     )>
   ];
 
-  let assemblyFormat = "`<` struct(params) `>`";
 }
 
+def XeGPU_ScatterTensorDescAttr: XeGPU_TensorDescAttr<"ScatterTensorDesc", "scatter_tdesc_attr"> {
+  let summary = [{a composite attribute for `TensorDescType`}];
+  let description = [{`ScatterTensorDesc` (or `scatter_tdesc_attr`) is a composite
+    attribute defined for `TensorDescType` for describing the following
+    properties of a `TensorDesc`.
+    1. `memory_space`: It describes where the data block described by the
+        TensorDesc is located, `Global` device memory or `Shared` local memory.
+        It defaults to `Global`.
+    2. `chunk_size`: indicates the number of contiguous elements accessed for each
+        offset; the default is 1. It is only meaningful for scattered TensorDescs.
+  }];
+
+  let parameters = (ins
+    OptionalParameter<"MemorySpaceAttr">: $memory_space,
+    OptionalParameter<"IntegerAttr", "1">: $chunk_size
+  );
+
+  let builders = [
+    AttrBuilder<(ins
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
+      CArg<"int", "1">: $chunk_size
+    )>
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU Memory Scope Enums.
 //===----------------------------------------------------------------------===//
-def XeGPU_MemoryScopeGlobal: I32EnumAttrCase<"Global", 0, "global">;
-def XeGPU_MemoryScopeShared: I32EnumAttrCase<"SLM", 1, "slm">;
-def XeGPU_MemoryScope: I32EnumAttr<"MemoryScope",
+def XeGPU_MemorySpaceGlobal: I32EnumAttrCase<"Global", 0, "global">;
+def XeGPU_MemorySpaceShared: I32EnumAttrCase<"SLM", 3, "slm">;
+def XeGPU_MemorySpace: I32EnumAttr<"MemorySpace",
       "The address space of the memory the tensor descritor is created for",
-      [XeGPU_MemoryScopeGlobal, XeGPU_MemoryScopeShared]> {
+      [XeGPU_MemorySpaceGlobal, XeGPU_MemorySpaceShared]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::xegpu";
 }
 
-def XeGPU_MemoryScopeAttr:
-  EnumAttr<XeGPU_Dialect, XeGPU_MemoryScope, "memory_scope"> {
+def XeGPU_MemorySpaceAttr:
+  EnumAttr<XeGPU_Dialect, XeGPU_MemorySpace, "memory_space"> {
     let summary = [{Describe the location of data described by a `TensorDesc`:
                  Global device memory (`Global`) or Shared local memory (`SLM`).}];
     let assemblyFormat = "$value";
@@ -116,4 +142,4 @@ def XeGPU_FenceScopeAttr:
     let assemblyFormat = "$value";
 }
 
-#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
\ No newline at end of file
+#endif // MLIR_DIALECT_XEGPU_IR_XEGPUATTRS_TD
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index c32c7541c39791..e24a056de2caf3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -218,6 +218,23 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
 
     mlir::Value getViewSource() { return getSource(); }
+
+    unsigned getSourceMemorySpace() {
+      auto srcTy = getSourceType();
+      if (auto memrefTy = llvm::dyn_cast<mlir::MemRefType>(srcTy)) {
+        auto attr = memrefTy.getMemorySpace();
+        if (attr) {
+          if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr)) {
+            return static_cast<unsigned>(intAttr.getInt());
+          }
+          if (auto memSpaceAttr = llvm::dyn_cast<MemorySpaceAttr>(attr))
+            return static_cast<unsigned>(memSpaceAttr.getValue());
+        }
+      }
+      // Take global as the default memory space.
+      return static_cast<unsigned>(MemorySpace::Global);
+    }
+
   }];
 }
 
@@ -411,8 +428,10 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
      is fixed to the hardware supported subgroup size, e.g., 16 on PVC,
       implying each element in the array corresponds to a work-item (SIMT lane)
       in the subgroup.
-    * chunk_size: [optional attribute] indicates number of continious
-      elements accessed for each offset, default is 1.
+
+    The first dimension of the result TensorDesc corresponds to work-items, so it should
+    match the number of offsets. It may also have a second dimension corresponding to
+    the chunk_size if the chunk size is larger than 1.
 
     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
@@ -424,29 +443,22 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
               It will access 32 data elements in total: a[0:7], a[16:23], a[32:39], a[64:71]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
     ```
 
    Example 3. It is similar to Example 2, but there are some overlaps among work-items.
               It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] {chunk_size = 8}: memref<1024xf32> -> TensorDesc<4x8xf32>
+    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
     ```
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
                        Variadic<Index>: $offsets,
-                       DenseI64ArrayAttr: $const_offsets,
-                       DefaultValuedAttr<I64Attr, "1">: $chunk_size);
+                       DenseI64ArrayAttr: $const_offsets);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);
 
-  let builders = [
-    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "Value": $source,
-                   "llvm::ArrayRef<OpFoldResult>": $offsets,
-                   CArg<"uint32_t", "1"> : $chunk_size)>,
-  ];
-
   let assemblyFormat = [{
     $source
     custom<DynamicIndexList>($offsets, $const_offsets)
@@ -473,6 +485,22 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
       assert(idx < getNumOffsets() && "Invalid out of bound access.");
       return getMixedOffsets()[idx];
     }
+
+    unsigned getSourceMemorySpace() {
+      auto srcTy = getSource().getType();
+      if (auto memrefTy = llvm::dyn_cast<mlir::MemRefType>(srcTy)) {
+        auto attr = memrefTy.getMemorySpace();
+        if (attr) {
+          if (auto intAttr = llvm::dyn_cast<mlir::IntegerAttr>(attr))
+            return static_cast<unsigned>(intAttr.getInt());
+          if (auto memSpaceAttr = llvm::dyn_cast<MemorySpaceAttr>(attr))
+            return static_cast<unsigned>(memSpaceAttr.getValue());
+        }
+      }
+      // Take global as the default memory space.
+      return static_cast<unsigned>(MemorySpace::Global);
+    }
+
   }];
 
   let hasVerifier = 1;
@@ -520,28 +548,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
 
  let description = [{ It (aka. load) loads data for each work-item. The output
    describes the data being loaded at the subgroup level, so its size is
-    consistent with the number of work-items in a subgroup. When `chunk_size_per_lane`
-    attribute is larger than 1 in TensorDesc, the output vector will be 2D vector,
-    with dim-1 correspoding to the chunk size.
+    consistent with the number of work-items in a subgroup. When the chunk size
+    is larger than 1, the output vector is a 2D vector, with dim-1 corresponding
+    to work-items, and dim-0 corresponding to the chunk_size loaded by each work-item.
+    Notably, there is a transpose effect on the result (as compared to the TensorDesc)
+    due to the hardware implementation. Therefore, a transpose attribute is introduced
+    on purpose, making sure users are aware of this implicit transformation.
 
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
   Example:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose = [1, 0],
+    %2 = xegpu.load %1, %0 {transpose,
                             l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
-          : !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
-            -> vector<16xf32>
+          : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
+            vector<16xi1> -> vector<16xf32>
   ```
 
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
                        XeGPU_MaskType: $mask,
-                       OptionalAttr<DenseI64ArrayAttr>: $transpose,
+                       OptionalAttr<UnitAttr>: $transpose,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
                        OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -573,11 +604,15 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
   let hasVerifier = 1;
 }
 
-def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDesc"]>,
-                                        AllElementTypesMatch<["value", "TensorDesc"]>]> {
+def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "TensorDesc"]>,
+                                              AllElementTypesMatch<["value", "TensorDesc"]>]> {
   let summary = "store data to scattered memory locations.";
-  let description = [{ It (aka. store) stores data to scattered memory locations.
-  It has similar semantic to `load_gather`.
+  let description = [{ It (aka. store) stores data to scattered memory locations. The value is
+  typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
+  a 2D vector instead. For the latter case, dim-1 of the value corresponds to the SIMD lanes
+  and dim-0 of the value corresponds to the chunk_size stored per lane. So `store_scatter`
+  has a transpose effect, similar to `load_gather`. Therefore, a transpose attribute is
+  introduced on purpose, making sure users are aware of this implicit transformation.
 
   Example:
   ```mlir
@@ -592,6 +627,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllShapesMatch<["value", "TensorDe
     XeGPU_ValueType: $value,
     XeGPU_TensorDesc: $TensorDesc,
     XeGPU_MaskType: $mask,
+    OptionalAttr<UnitAttr>: $transpose,
     OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
     OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -723,7 +759,7 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
 
 def XeGPU_AtomicRMWOp: XeGPU_Op<"atomic_rmw", [Pure,
       AllElementTypesMatch<["tensorDesc", "value", "result"]>,
-      AllShapesMatch<["tensorDesc", "mask", "value", "result"]>]> {
+      AllShapesMatch<["tensorDesc", "value", "result"]>]> {
   let summary = "Atomic ready-modify-write operation on the TensorDesc. ";
 
   let description = [{
@@ -808,7 +844,7 @@ def XeGPU_FenceOp: XeGPU_Op<"fence", []> {
     2. `Fence_scope` describes the scope of fence. "Workgroup" means that the scope would be
         within each workgroup. "GPU" means the scope would be across workgroups within the GPU.
   }];
-  let arguments = (ins XeGPU_MemoryScopeAttr: $memory_kind,
+  let arguments = (ins XeGPU_MemorySpaceAttr: $memory_kind,
                        XeGPU_FenceScopeAttr: $fence_scope);
   let assemblyFormat = [{`memory_kind` `=` `` $memory_kind `,` `fence_scope` `=` `` $fence_scope attr-dict}];
   let extraClassDeclaration = extraBaseClassDeclaration;
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
index 9f101a71697b56..0ce1211664b5ba 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -48,7 +48,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
 
    Similar to the builtin tensor, it also provides an optional attribute to encode
    the following information via the TensorDescAttr object:
-    * memory_scope (xegpu::MemoryScope): [optional] where the data is located,
+    * memory_space (xegpu::MemorySpace): [optional] where the data is located,
                global memory or shared memory. It defaults to Global.
    * array_length (int): [optional] The number of contiguous blocks with size as `shape`,
               that will be loaded by block load at a time. It defaults to 1.
@@ -63,7 +63,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     element-type ::= float-type | integer-type | index-type
     dim-list := (static-dim-list `x`)?
     static-dim-list ::= decimal-literal `x` decimal-literal
-    attr-list = (, memory_scope = value)? (, arr_len = value)? (, boundary_check = value)? (, scattered = value)?
+    attr-list = (, memory_space = value)? (, arr_len = value)? (, boundary_check = value)?
     ```
 
     Examples:
@@ -76,7 +76,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     xegpu.tensor_desc<8x16xf32>
 
     // A TensorDesc with 8x16 f32 elements for a memory region in shared memory space.
-    xegpu.tensor_desc<8x16xf32, #xegpu.tdesc_attr<memory_scope = slm>>
+    xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
     ```
   }];
 
@@ -88,11 +88,14 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     TypeBuilderWithInferredContext<(ins
       "llvm::ArrayRef<int64_t>": $shape,
       "mlir::Type": $elementType,
-      CArg<"bool", "false">: $scattered,
       CArg<"int", "1">: $array_length,
-      CArg<"xegpu::MemoryScope", "xegpu::MemoryScope::Global">:$memory_scope,
-      CArg<"bool", "true">: $boundary_check
-    )>
+      CArg<"bool", "true">: $boundary_check,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>,
+    TypeBuilderWithInferredContext<(ins
+      "llvm::ArrayRef<int64_t>": $shape,
+      "mlir::Type": $elementType,
+      CArg<"int", "1">: $chunk_size,
+      CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space)>
   ];
 
   let extraClassDeclaration = [{
@@ -110,40 +113,58 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
       return llvm::cast<TensorDescType>(cloneWith(getShape(), elementType));
     }
 
-    TensorDescAttr getEncodingAsTensorDescAttr() const {
-      return llvm::dyn_cast_if_present<TensorDescAttr>(getEncoding());
+    BlockTensorDescAttr getEncodingAsBlockTensorDescAttr() const {
+      return llvm::dyn_cast_if_present<BlockTensorDescAttr>(getEncoding());
     }
 
-    xegpu::MemoryScope getMemoryScope() const {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getMemoryScope())
-        return attr.getMemoryScope().getValue();
+    ScatterTensorDescAttr getEncodingAsScatterTensorDescAttr() const {
+      return llvm::dyn_cast_if_present<ScatterTensorDescAttr>(getEncoding());
+    }
+
+    xegpu::MemorySpace getMemorySpace() const {
+      auto block_attr = getEncodingAsBlockTensorDescAttr();
+      if (block_attr && block_attr.getMemorySpace())
+        return block_attr.getMemorySpace().getValue();
+
+      auto scatter_attr = getEncodingAsScatterTensorDescAttr();
+      if (scatter_attr && scatter_attr.getMemorySpace())
+        return scatter_attr.getMemorySpace().getValue();
+
       // return default value
-      return MemoryScope::Global;
+      return MemorySpace::Global;
     }
 
     int getArrayLength() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getArrayLength())
-        return attr.getArrayLength().getInt();
+      auto attr = getEncoding();
+      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
+      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
+      if (block_attr && block_attr.getArrayLength())
+        return block_attr.getArrayLength().getInt();
       // return default value
       return 1;
     }
 
     bool getBoundaryCheck() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getBoundaryCheck())
-        return attr.getBoundaryCheck().getValue();
+      auto attr = getEncoding();
+      auto block_attr = mlir::dyn_cast_if_present<BlockTensorDescAttr>(attr);
+      assert((!attr || block_attr) && "invalid on non BlockTensorDescAttr.");
+      if (block_attr && block_attr.getBoundaryCheck())
+        return block_attr.getBoundaryCheck().getValue();
       // return default value
       return true;
     }
 
-    bool getScattered() {
-      auto attr = getEncodingAsTensorDescAttr();
-      if (attr && attr.getScattered())
-        return attr.getScattered().getValue();
-      // return default value
-      return false;
+    bool isScattered() {
+      return bool(getEncodingAsScatterTensorDescAttr());
+    }
+
+    int getChunkSize() {
+      auto attr = getEncoding();
+      auto scatter_attr = mlir::dyn_cast_if_present<ScatterTensorDescAttr>(attr);
+      assert((!attr || scatter_attr) && "invalid on non ScatterTensorDescAttr.");
+      if (scatter_attr && scatter_attr.getChunkSize())
+        return scatter_attr.getChunkSize().getInt();
+      return 1;
     }
   }];
 
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24719fe748fe4f..1dfbaed454c193 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -30,23 +30,35 @@ void XeGPUDialect::initialize() {
 }
 
 //===-------------------------------------------...
[truncated]

``````````
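
For item 4, a hedged sketch of what the new memory-space check on `create_nd_tdesc` accepts and rejects; the SLM address-space number 3 follows the OpenCL numbering mentioned above, and the shapes here are illustrative:

```mlir
// Valid: an SLM source paired with an SLM descriptor.
%slm = memref.alloc() : memref<8x16xf32, 3>
%td  = xegpu.create_nd_tdesc %slm[0, 0] : memref<8x16xf32, 3>
       -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>

// Invalid: the descriptor claims global memory while the source lives in SLM;
// the new verifier is expected to reject this mismatch.
%bad = xegpu.create_nd_tdesc %slm[0, 0] : memref<8x16xf32, 3>
       -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = global>>
```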



https://github.com/llvm/llvm-project/pull/109144

