[Mlir-commits] [mlir] [mlir][xegpu] Relax rank restriction of TensorDescType (PR #145916)

Chao Chen llvmlistbot at llvm.org
Wed Jul 9 12:55:29 PDT 2025


================
@@ -446,4 +446,87 @@ gpu.module @test_kernel   {
   }
 }
 
+// -----
+#l = #xegpu.layout<inst_data = [8,32,16]>
+gpu.module @test_kernel {
+  // CHECK-LABEL: test_3d_block_tensor_desc
+  // With inst_data = [8,32,16], each 32x32x32 tensor_desc is expected to be
+  // blocked into (32/8) * (32/32) * (32/16) = 8 tiles of 8x32x16. The checks
+  // below therefore expect: 2 loads -> 16 load_nd, 1 addf -> 8 addf,
+  // 1 store -> 8 store_nd, and 3 offset updates -> 24 update_nd_offset.
+  gpu.func @test_3d_block_tensor_desc(%A: memref<1024x1024x1024xf16>, %B: memref<1024x1024x1024xf16>, %C: memref<1024x1024x1024xf16>) {
+    %c0 = arith.constant 0 : index
+    %c32 = arith.constant 32 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id x
+    %m = arith.muli %block_id_x, %c32 : index
+
+    %a_tdesc = xegpu.create_nd_tdesc %A[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+    %b_tdesc = xegpu.create_nd_tdesc %B[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+    %c_tdesc = xegpu.create_nd_tdesc %C[%m, %m, %c0] : memref<1024x1024x1024xf16> -> !xegpu.tensor_desc<32x32x32xf16, #l>
+
+    %out:3 = scf.for %k = %c0 to %c1024 step %c32
+      iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_tdesc)
+      -> (!xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>) {
+      // CHECK-COUNT-16: xegpu.load_nd {{.*}}  : !xegpu.tensor_desc<8x32x16xf16> -> vector<8x32x16xf16>
+      %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16>
+      %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32x32xf16, #l> -> vector<32x32x32xf16>
+
+      // CHECK-COUNT-8: arith.addf {{.*}} : vector<8x32x16xf16>
+      %c = arith.addf %a, %b {layout_result_0 = #l} : vector<32x32x32xf16>
+
+      // CHECK-COUNT-8: xegpu.store_nd {{.*}} : vector<8x32x16xf16>, !xegpu.tensor_desc<8x32x16xf16>
+      xegpu.store_nd %c, %arg2: vector<32x32x32xf16>, !xegpu.tensor_desc<32x32x32xf16, #l>
+
+      // CHECK-COUNT-24: xegpu.update_nd_offset {{.*}} : !xegpu.tensor_desc<8x32x16xf16>
+      %a_next_tdesc = xegpu.update_nd_offset %arg0, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+      %b_next_tdesc = xegpu.update_nd_offset %arg1, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+      %c_next_tdesc = xegpu.update_nd_offset %arg2, [0, 0, %c32] : !xegpu.tensor_desc<32x32x32xf16, #l>
+      scf.yield %a_next_tdesc, %b_next_tdesc, %c_next_tdesc
+        : !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>, !xegpu.tensor_desc<32x32x32xf16, #l>
+    }
+    gpu.return
+  }
+}
 
+// -----
+#l = #xegpu.layout<inst_data = [2, 8, 2]>
+gpu.module @test_kernel   {
+  // CHECK-LABEL: test_3d_scattered_tensor_desc
+  // CHECK-SAME: [[arg0:%.+]]: ui64
+  // CHECK-COUNT-4: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<2x8xindex> -> !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+  // CHECK-COUNT-4: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>
+   // CHECK-COUNT-4: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xindex>
+   // CHECK-COUNT-4: xegpu.load  {{.*}} : !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1> -> vector<2x8x2xf32>
+  // CHECK-COUNT-4: xegpu.store  {{.*}} : vector<2x8x2xf32>, !xegpu.tensor_desc<2x8x2xf32, #xegpu.scatter_tdesc_attr<chunk_size = 2 : i64>>, vector<2x8xi1>
+
+
+  gpu.func @test_3d_scattered_tensor_desc(%src: ui64)  {
+
+    %cst = arith.constant dense<[
----------------
chencha3 wrote:

Yes, added. 

https://github.com/llvm/llvm-project/pull/145916


More information about the Mlir-commits mailing list