[Mlir-commits] [mlir] 9c697b3 - [MLIR][XeGPU] Update the type of offsets for CreateDescOp and UpdateOffsetOp (#110741)

llvmlistbot at llvm.org
Wed Oct 2 07:18:45 PDT 2024


Author: Chao Chen
Date: 2024-10-02T09:18:41-05:00
New Revision: 9c697b3a02d95b49e11633c45f76f77954fca704

URL: https://github.com/llvm/llvm-project/commit/9c697b3a02d95b49e11633c45f76f77954fca704
DIFF: https://github.com/llvm/llvm-project/commit/9c697b3a02d95b49e11633c45f76f77954fca704.diff

LOG: [MLIR][XeGPU] Update the type of offsets for CreateDescOp and UpdateOffsetOp (#110741)

This PR changes the type of the `offsets` operand of CreateDescOp and
UpdateOffsetOp to a 1D vector of index, for the convenience of users.
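
For illustration, a minimal before/after sketch of the syntax change, assembled from the test updates below (the offset values are illustrative):

```mlir
// Before: offsets were given as an inline index array attached to the source.
%1 = xegpu.create_tdesc %src[0, 8, 16, 24]
      : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>

// After: offsets are an explicit 1D vector-of-index operand.
%0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
%1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
      -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
```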

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
    mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
    mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
    mlir/test/Dialect/XeGPU/XeGPUOps.mlir
    mlir/test/Dialect/XeGPU/invalid.mlir

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
index 7ac0cf77fe59bb..d6c51d20571fd3 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/TypeUtilities.h"

diff --git a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
index e24a056de2caf3..239ce0aa8e0035 100644
--- a/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
+++ b/mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -424,9 +424,9 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     It accepts the following parameters:
 
    * source: a 1D memref or pointer (uint64_t) representing the flattened memory object.
-    * offsets: a array containing offsets of each access point. Its size
+    * offsets: a vector containing offsets of each access point. Its size
      is fixed to the hardware supported subgroup size, e.g., 16 on PVC,
-      implying each element in the array corresponds to a work-item (SIMT lane)
+      implying each element in the vector corresponds to a work-item (SIMT lane)
       in the subgroup.
 
     The first dimension of the result TensorDesc corresponds to work-items, so it should
@@ -436,56 +436,59 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %a[0, 16, 32, 64]: memref<1024xf32> -> TensorDesc<4xf32>
+    %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+    %1 = xegpu.create_tdesc %a, %0: memref<1024xf32>, vector<4xindex> -> TensorDesc<4xf32>
     ```
 
    Example 2. It assumes subgroup size is 4, and each work-item accesses 8 elements.
               In total it accesses 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 16, 32, 64] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>
+    %off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+    %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
+          -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
     ```
 
    Example 3. It is similar to Example 2, but there is some overlap among work-items.
                It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
-    %1 = xegpu.create_tdesc %0[0, 4, 8, 12] : memref<1024xf32> -> TensorDesc<4x8xf32, chunk_size = 8>>
+    %off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
+    %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
+          -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
     ```
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
-                       Variadic<Index>: $offsets,
-                       DenseI64ArrayAttr: $const_offsets);
+                       XeGPU_OffsetType: $offsets);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);
 
+  let builders = [
+    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "mlir::Value": $source,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets)>,
+    OpBuilder<(ins "xegpu::TensorDescType": $TensorDesc, "mlir::Value": $source,
+                   "llvm::ArrayRef<int64_t>": $offsets)>,
+  ];
+
   let assemblyFormat = [{
-    $source
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    attr-dict `:`  type($source) `->` qualified(type($TensorDesc))
+    $source `,` $offsets attr-dict `:`  type($source) `,` type($offsets) `->` qualified(type($TensorDesc))
   }];
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
+  let extraClassDeclaration = [{
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
 
-    SmallVector<OpFoldResult> getMixedOffsets() {
-      Builder b(getContext());
-      return getMixedValues(getConstOffsets(), getOffsets(), b);
+    mlir::VectorType getOffsetsType() {
+      return getOffsets().getType();
     }
 
     size_t getNumOffsets() {
-      return getMixedOffsets().size();
+      return getOffsetsType().getNumElements();
     }
 
     mlir::Value getViewSource() { return getSource(); }
 
-    OpFoldResult getOffset(unsigned idx) {
-      assert(idx < getNumOffsets() && "Invalid out of bound access.");
-      return getMixedOffsets()[idx];
-    }
-
     unsigned getSourceMemorySpace() {
       auto srcTy = getSource().getType();
       if (auto memrefTy = llvm::dyn_cast<mlir::MemRefType>(srcTy)) {
@@ -550,7 +553,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
     describes the data being loaded at the subgroup level, so its size is
     consistent with the number of work-items in a subgroup. When the chunk size
    is larger than 2, the output vector is a 2D vector, with dim-1 corresponding
-    to work-items, and dim-0 corresponding to the chunk_size loaded by each work-item.
+    to work-items, and dim-0 corresponding to the chunk size loaded by each work-item.
    Specifically, there is a transpose effect on the result (as compared to the TensorDesc)
     due to the hardware implementation. Therefore, a transpose attribute is introduced
     on purpose, making sure users are aware of this implicit transformation.
@@ -558,16 +561,25 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [AllRanksMatch<["value", "TensorDesc"]
    The mask operand masks out memory accesses so that it is safe to pass out-of-bounds
    addresses/offsets as long as they are masked. It applies per SIMD lane.
 
-  Example:
+  Example 1:
   ```mlir
-    %2 = xegpu.load %1, %0 {transpose,
-                            l1_hint = #xegpu.cache_hint<cached>,
+    %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
                             l2_hint = #xegpu.cache_hint<uncached>,
                             l3_hint = #xegpu.cache_hint<uncached>}
           : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<memory_space=global>>,
             vector<16xi1> -> vector<16xf32>
   ```
 
+  Example 2:
+  ```mlir
+    %2 = xegpu.load %1, %0 {transpose,
+                            l1_hint = #xegpu.cache_hint<cached>,
+                            l2_hint = #xegpu.cache_hint<uncached>,
+                            l3_hint = #xegpu.cache_hint<uncached>}
+          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
+            vector<16xi1> -> vector<8x16xf32>
+  ```
+
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
@@ -610,17 +622,27 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [AllElementCountsMatch<["value", "T
   let description = [{ It (aka. store) stores data to scattered memory locations. The value is
   typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
  a 2D vector instead. For the latter case, dim-1 of the value corresponds to the SIMD lanes
-  and the dim-0 of the value corresponds to the chunk_size stored per lane. So `store_scatter`
+  and the dim-0 of the value corresponds to the chunk size stored per lane. So `store_scatter`
  has a transpose effect, similar to `load_gather`. Therefore, a transpose attribute is
   introduced on purpose, making sure users are aware of this implicit transformation.
 
-  Example:
+  Example 1:
   ```mlir
     %3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                                  l2_hint = #xegpu.cache_hint<write_back>,
                                  l3_hint = #xegpu.cache_hint<write_through>}
-          : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.tdesc_attr<scattered=true>>, vector<16xi1>
+          : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1>
+  ```
+
+  Example 2:
+  ```mlir
+    %3 = xegpu.store %0, %1, %2 {transpose,
+                                 l1_hint = #xegpu.cache_hint<uncached>,
+                                 l2_hint = #xegpu.cache_hint<write_back>,
+                                 l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
   ```
+
   }];
 
   let arguments = (ins
@@ -666,40 +688,39 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
 
     Example:
     ```mlir
-      %2 = xegpu.update_offset %1, [32, 32, 32, 32]
-            : !xegpu.tensor_desc<4x2xf32, #xegpu.tdesc_attr<scattered = true>>
+      %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
+      %2 = xegpu.update_offset %1, %off :
+              !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<>>, vector<4xindex>
     ```
   }];
 
   let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
-                       Variadic<Index>: $offsets,
-                       DenseI64ArrayAttr: $const_offsets);
+                       XeGPU_OffsetType: $offsets);
   let results = (outs XeGPU_TensorDesc: $result);
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
+  let builders = [
+    OpBuilder<(ins "mlir::Value": $TensorDesc,
+                   "llvm::ArrayRef<OpFoldResult>": $offsets)>,
+    OpBuilder<(ins "mlir::Value": $TensorDesc,
+                   "llvm::ArrayRef<int64_t>": $offsets)>
+  ];
+
+  let extraClassDeclaration = [{
     xegpu::TensorDescType getTensorDescType() {
       return getTensorDesc().getType();
     }
 
-    SmallVector<OpFoldResult> getMixedOffsets() {
-      Builder b(getContext());
-      return getMixedValues(getConstOffsets(), getOffsets(), b);
+    mlir::VectorType getOffsetsType() {
+      return getOffsets().getType();
     }
 
     size_t getNumOffsets() {
-      return getMixedOffsets().size();
-    }
-
-    OpFoldResult getOffset(unsigned idx) {
-      assert(idx < getNumOffsets() && "Invalid out of bound access.");
-      return getMixedOffsets()[idx];
+      return getOffsetsType().getNumElements();
     }
   }];
 
   let assemblyFormat = [{
-    $TensorDesc `,`
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    attr-dict `:` qualified(type($TensorDesc))
+    $TensorDesc `,` $offsets attr-dict `:` qualified(type($TensorDesc)) `,` type($offsets)
   }];
 }
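
As a usage sketch, the new operand form composes with the chunked gather from Example 2 of `load` above; the subgroup size of 16 and `chunk_size = 8` come from that example, while `%src` (assumed to be a ui64 base address, as in the tests below) and the concrete offset values are illustrative:

```mlir
// One offset per work-item (SIMT lane); each lane loads a chunk of 8 elements.
%off = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56,
                             64, 72, 80, 88, 96, 104, 112, 120]> : vector<16xindex>
%1 = xegpu.create_tdesc %src, %off : ui64, vector<16xindex>
      -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
%mask = arith.constant dense<true> : vector<16xi1>
// Note the transpose effect: the result is vector<8x16xf32>, not vector<16x8xf32>.
%2 = xegpu.load %1, %mask {transpose}
      : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
        vector<16xi1> -> vector<8x16xf32>
```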
 

diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 1a7a6b34784099..5bd3c370e38594 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
@@ -308,6 +309,24 @@ LogicalResult UpdateNdOffsetOp::verify() {
 // XeGPU_CreateDescOp
 //===----------------------------------------------------------------------===//
 
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source,
+                         llvm::ArrayRef<OpFoldResult> offsets) {
+  auto loc = source.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get(size, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
+  build(builder, state, TensorDesc, source, offset);
+}
+
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+                         TensorDescType TensorDesc, Value source,
+                         llvm::ArrayRef<int64_t> offsets) {
+  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
+  build(builder, state, TensorDesc, source, ofrs);
+}
+
 LogicalResult CreateDescOp::verify() {
   auto tdescTy = getTensorDescType();
 
@@ -473,6 +492,29 @@ LogicalResult StoreScatterOp::verify() {
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateOffsetOp
+//===----------------------------------------------------------------------===//
+void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
+                           mlir::Value tensorDesc,
+                           llvm::ArrayRef<OpFoldResult> offsets) {
+  auto tdescTy = mlir::dyn_cast<TensorDescType>(tensorDesc.getType());
+  assert(tdescTy && "Expecting the source is a TensorDescType value.");
+  auto loc = tensorDesc.getLoc();
+  int64_t size = static_cast<int64_t>(offsets.size());
+  auto type = VectorType::get({size}, builder.getIndexType());
+  auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets);
+  auto offset = builder.create<vector::FromElementsOp>(loc, type, values);
+  build(builder, state, tdescTy, tensorDesc, offset);
+}
+
+void UpdateOffsetOp::build(OpBuilder &builder, OperationState &state,
+                           Value tensorDesc, llvm::ArrayRef<int64_t> offsets) {
+  auto ofrs = getAsIndexOpFoldResult(builder.getContext(), offsets);
+  build(builder, state, tensorDesc, ofrs);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_DpasOp
 //===----------------------------------------------------------------------===//
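
For reference, this is roughly the IR the new `ArrayRef<OpFoldResult>` builders materialize when all offsets are static: `getValueOrCreateConstantIndexOp` produces one `arith.constant` of index type per offset, and `vector::FromElementsOp` assembles them into the vector operand (a sketch; the offset values and `%src` are illustrative):

```mlir
%c0  = arith.constant 0 : index
%c8  = arith.constant 8 : index
%c16 = arith.constant 16 : index
%c24 = arith.constant 24 : index
// The elements become the single vector<4xindex> offsets operand.
%off = vector.from_elements %c0, %c8, %c16, %c24 : vector<4xindex>
%1 = xegpu.create_tdesc %src, %off : ui64, vector<4xindex>
      -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
```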

diff --git a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
index c1126efb6046dc..6db57aad773aa8 100644
--- a/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
+++ b/mlir/test/Dialect/XeGPU/XeGPUOps.mlir
@@ -104,22 +104,28 @@ gpu.func @test_create_update_nd_tdesc_vc(%src: memref<24x32xf32>) {
 
 // CHECK: gpu.func @test_create_tdesc_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_tdesc_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   gpu.return
 }
 
 // CHECK: gpu.func @test_create_tdesc_vc_1(%[[arg0:.*]]: memref<?xf32, 3>) {
 gpu.func @test_create_tdesc_vc_1(%src: memref<?xf32, 3>) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : memref<?xf32, 3> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space =  slm, chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : memref<?xf32, 3>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>>
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : memref<?xf32, 3>, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space =  slm, chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32, 3>, vector<4xindex>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>>
   gpu.return
 }
 
 // CHECK: gpu.func @test_prefetch_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_prefetch_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // CHECK: xegpu.prefetch %[[R0]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}> : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   gpu.return
@@ -129,8 +135,10 @@ gpu.func @test_prefetch_vc(%src: ui64) {
 gpu.func @test_load_gather_vc(%src: ui64) {
   //CHECK: %[[cst:.*]] = arith.constant dense<true> : vector<4xi1>
   %0 = arith.constant dense<1>: vector<4xi1>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %1 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   //CHECK: %[[R1:.*]] = xegpu.load %[[R0]], %[[cst]] <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
   //CHECK-SAME: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1> -> vector<2x4xf32>
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
@@ -144,8 +152,10 @@ gpu.func @test_store_scatter_vc(%src: ui64) {
   %0 = arith.constant dense<1>: vector<4xi1>
   //CHECK: %[[c1:.*]] = arith.constant dense<2.900000e+00> : vector<2x4xf32>
   %1 = arith.constant dense<2.9>: vector<2x4xf32>
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %2 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[c2:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %c = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c2]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  %2 = xegpu.create_tdesc %src, %c : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   //CHECK: xegpu.store %[[c1]], %[[R0]], %[[c0]] <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
   //CHECK-SAME: vector<2x4xf32>, !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xi1>
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<write_back>, l2_hint = #xegpu.cache_hint<uncached>, transpose}>
@@ -155,10 +165,14 @@ gpu.func @test_store_scatter_vc(%src: ui64) {
 
 // CHECK: gpu.func @test_create_update_tdesc_vc(%[[arg0:.*]]: ui64) {
 gpu.func @test_create_update_tdesc_vc(%src: ui64) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %arg0 [0, 8, 16, 24] : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24]: ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
-  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
-  %2 = xegpu.update_offset %1, [32, 32, 32, 32] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  //CHECK: %[[cst:.*]] = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[cst]] : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  //CHECK: %[[st:.*]] = arith.constant dense<32> : vector<4xindex>
+  //CHECK: %[[R1:.*]] = xegpu.update_offset %[[R0]], %[[st]] : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<4xindex>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex> -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %s = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
+  %2 = xegpu.update_offset %1, %s : !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xindex>
   gpu.return
 }
 
@@ -179,8 +193,10 @@ gpu.func @test_dpas_vc_with_packed_b(%a : vector<8x16xf16>, %b: vector<8x16x2xf1
 
 // CHECK: gpu.func @test_atomic_rmw(%[[arg0:.*]]: ui64, %[[arg1:.*]]: vector<16xf32>, %[[arg2:.*]]: vector<16xi1>)
 gpu.func @test_atomic_rmw(%src: ui64, %value : vector<16xf32>, %mask : vector<16xi1>) {
-  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
-  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]: ui64 -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  //CHECK: %[[c:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+  %c = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+  //CHECK: %[[R0:.*]] = xegpu.create_tdesc %[[arg0]], %[[c]] : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.create_tdesc %src, %c: ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
   //CHECK: %[[R1:.*]] = xegpu.atomic_rmw addf %[[R0]], %[[arg2]], %[[arg1]] : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>, vector<16xf32> -> vector<16xf32>
   gpu.return

diff --git a/mlir/test/Dialect/XeGPU/invalid.mlir b/mlir/test/Dialect/XeGPU/invalid.mlir
index 193dae352e3707..f8a0d95bd70a27 100644
--- a/mlir/test/Dialect/XeGPU/invalid.mlir
+++ b/mlir/test/Dialect/XeGPU/invalid.mlir
@@ -39,8 +39,9 @@ func.func @test_prefetch_nd_vc_1(%src: memref<24x32xf16>) {
 
 // -----
 func.func @test_prefetch_nd_vc_2(%src: memref<24xf16>) {
-  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7]
-        : memref<24xf16> -> !xegpu.tensor_desc<8xf16, #xegpu.scatter_tdesc_attr<>>
+  %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+  %1 = xegpu.create_tdesc %src, %0 : memref<24xf16>, vector<8xindex>
+                -> !xegpu.tensor_desc<8xf16, #xegpu.scatter_tdesc_attr<>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
   xegpu.prefetch_nd %1 <{l1_hint = #xegpu.cache_hint<cached>}>
         : !xegpu.tensor_desc<8xf16, #xegpu.scatter_tdesc_attr<>>
@@ -58,8 +59,9 @@ func.func @test_load_nd_vc_1(%src: memref<8x16xf16>) {
 
 // -----
 func.func @test_load_nd_vc_2(%src: memref<16xf16>) {
-  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14]
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
+  %1 = xegpu.create_tdesc %src, %0 : memref<16xf16>, vector<8xindex>
+          -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc.}}
   %2 = xegpu.load_nd %1 <{l1_hint = #xegpu.cache_hint<cached>}>
       : !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>> -> vector<8x2xf16>
@@ -86,9 +88,10 @@ func.func @test_store_nd_vc_1(%dst: memref<24x32xf16>) {
 
 // -----
 func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
+  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
   %1 = arith.constant dense<1.0>: vector<8x2xf16>
-  %2 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14]
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %2 = xegpu.create_tdesc %dst, %0 : memref<16xf16>, vector<8xindex>
+            -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
   xegpu.store_nd %1, %2 <{l1_hint = #xegpu.cache_hint<streaming>}>
         : vector<8x2xf16>, !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
@@ -97,8 +100,9 @@ func.func @test_store_nd_vc_2(%dst: memref<16xf16>) {
 
 // -----
 func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
-  %1 = xegpu.create_tdesc %dst[0, 2, 4, 6, 8, 10, 12, 14]
-        : memref<16xf16> -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
+  %1 = xegpu.create_tdesc %dst, %0 : memref<16xf16>, vector<8xindex>
+            -> !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{Expects a non-scattered TensorDesc}}
   xegpu.update_nd_offset %1, [0, 2] : !xegpu.tensor_desc<8x2xf16, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
@@ -106,24 +110,27 @@ func.func @test_update_nd_offset_1(%dst: memref<16xf16>) {
 
 // -----
 func.func @test_create_tdesc_vc_1(%src: ui64) {
+  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
   // expected-error at +1 {{Expects a scattered TensorDesc}}
-  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14]
-        : ui64 -> !xegpu.tensor_desc<8xf16>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex> -> !xegpu.tensor_desc<8xf16>
   return
 }
 
 // -----
 func.func @test_create_tdesc_vc_2(%src: ui64) {
+  %0 = arith.constant dense<[0, 2, 4, 6, 8, 10, 12, 14]> : vector<8xindex>
   // expected-error at +1 {{Incorrect TensorDesc shape}}
-  %1 = xegpu.create_tdesc %src[0, 2, 4, 6, 8, 10, 12, 14] {chunk_size = 2}
-        : ui64 -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<8xindex>
+          -> !xegpu.tensor_desc<8x4xf16, #xegpu.scatter_tdesc_attr<>>
   return
 }
 
 // -----
 func.func @test_create_tdesc_vc_1(%src: memref<?xf32>) {
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   // expected-error at +1 {{Memory space mismatch}}
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : memref<?xf32>  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>>
+  %1 = xegpu.create_tdesc %src, %0 : memref<?xf32>, vector<4xindex>
+          -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<memory_space = slm, chunk_size = 2>>
   return
 }
 
@@ -137,7 +144,9 @@ func.func @test_prefetch_vc_1(%src: memref<24x32xf16>) {
 
 // -----
 func.func @test_prefetch_vc_2(%src: ui64) {
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64  -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %0 = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<4xindex>
+          -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
   xegpu.prefetch %1 <{l1_hint = #xegpu.cache_hint<write_back>}>: !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   return
@@ -155,8 +164,9 @@ func.func @test_load_gather_vc_1(%src: memref<24x32xf16>) {
 
 // -----
 func.func @test_load_gather_vc_2(%src: ui64) {
+  %cst = arith.constant dense<[0, 8, 16, 24]> : vector<4xindex>
   %0 = arith.constant dense<1>: vector<4xi1>
-  %1 = xegpu.create_tdesc %src[0, 8, 16, 24] : ui64
+  %1 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
         -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<write_back>}}
   %2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<write_back>}>
@@ -178,10 +188,11 @@ func.func @test_store_scatter_vc_1(%src: memref<24x32xf32>) {
 
 // -----
 func.func @test_store_scatter_vc_2(%src: ui64) {
+  %cst = arith.constant dense<[0, 8, 16, 24]>: vector<4xindex>
   %0 = arith.constant dense<1>: vector<4xi1>
   %1 = arith.constant dense<2.9>: vector<4x2xf32>
-  %2 = xegpu.create_tdesc %src[0, 8, 16, 24]
-          : ui64 -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
+  %2 = xegpu.create_tdesc %src, %cst : ui64, vector<4xindex>
+              -> !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>
   // expected-error at +1 {{invlid l1_hint: #xegpu.cache_hint<streaming>}}
   xegpu.store %1, %2, %0 <{l1_hint = #xegpu.cache_hint<streaming>}> : vector<4x2xf32>,
           !xegpu.tensor_desc<4x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2>>, vector<4xi1>
@@ -204,7 +215,8 @@ func.func @test_dpas_vc_2(%a : vector<8x8x2xf16>, %b: vector<8x16x2xf16>) {
 
 // -----
 func.func @test_atomic_rmw(%src: ui64, %value : vector<16x4xf32>, %mask : vector<16xi1>) {
-  %1 = xegpu.create_tdesc %src[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : ui64 -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
+  %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xindex>
+  %1 = xegpu.create_tdesc %src, %0 : ui64, vector<16xindex> -> !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>
   // expected-error at +1 {{failed to verify that all of {tensorDesc, value, result} have same shape}}
   xegpu.atomic_rmw addf %1, %mask, %value: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<chunk_size = 8>>, vector<16xi1>, vector<16x4xf32> -> vector<16x8xf32>
   return


        

