[Mlir-commits] [mlir] [mlir][xegpu] Add definitions of MatrixDescType and related ops. (PR #153273)

Jianhui Li llvmlistbot at llvm.org
Wed Aug 13 15:36:40 PDT 2025


================
@@ -751,4 +751,65 @@ gpu.func @fence() {
   gpu.return
 }
 
+// Creates a 16x64xf16 matrix_desc (16 * 64 * 2 bytes = 2048 bytes) over a
+// 2048xi8 memref in memory space 3, with no explicit layout on the descriptor.
+// CHECK-LABEL: gpu.func @create_matrix_desc({{.*}}) {
+gpu.func @create_matrix_desc() {
+  // CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  // CHECK: [[mdesc:%.+]] = xegpu.create_matrix_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %matrix_desc = xegpu.create_matrix_desc %m : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+  gpu.return
+}
+
+// Same allocation as the plain variant, but the resulting descriptor carries an
+// explicit strided<[1, 16]> layout that must round-trip through the printer.
+// CHECK-LABEL: gpu.func @create_matrix_desc_with_stride({{.*}}) {
+gpu.func @create_matrix_desc_with_stride() {
+  // CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  // CHECK: [[mdesc:%.+]] = xegpu.create_matrix_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+  %m = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %matrix_desc = xegpu.create_matrix_desc %m : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+  gpu.return
+}
+
+// Round-trips an xegpu.load_matrix that reads an 8x16 tile at offset [8, 8]
+// from a 16x64xf16 matrix_desc passed in as a function argument.
+// CHECK: gpu.func @load_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>)
+gpu.func @load_matrix_desc(%arg0: !xegpu.matrix_desc<16x64xf16>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+  %tile = xegpu.load_matrix %arg0[8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+  gpu.return
+}
+
+// Round-trips an xegpu.load_matrix (8x16 tile at offset [8, 8]) on a 16x64xf16
+// matrix_desc that carries a strided<[1, 16]> layout. The argument is captured
+// here so the body check below does not depend on a capture from an earlier test.
+// CHECK: gpu.func @load_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>)
+gpu.func @load_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>) {
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+  gpu.return
+}
+
+
+// Round-trips an xegpu.store_matrix that writes an 8x16 tile at offset [8, 8]
+// into a 16x64xf16 matrix_desc. Assuming the offsets are [row, col], an 8-row
+// tile starting at row 8 stays inside the 16 rows (a 16x16 tile would cover rows 8-23).
+// CHECK: gpu.func @store_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>, [[ARG1:%.+]]: vector<8x16xf16>)
+gpu.func @store_matrix_desc(%arg0: !xegpu.matrix_desc<16x64xf16>, %arg1: vector<8x16xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG0]][8, 8], [[ARG1]] : !xegpu.matrix_desc<16x64xf16>, vector<8x16xf16>
+  xegpu.store_matrix %arg0[8, 8], %arg1: !xegpu.matrix_desc<16x64xf16>, vector<8x16xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @store_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, %arg1: vector<16x16xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG0]][8, 8], [[ARG1]] : !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, vector<16x16xf16>
+  xegpu.store_matrix %arg0[8, 8], %arg1: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, vector<16x16xf16>
----------------
Jianhui-Li wrote:

Maybe change vector<16x16xf16> to vector<8x16xf16>: at offset [8, 8], a 16-row tile would run past the end of the 16x64 matrix_desc.

https://github.com/llvm/llvm-project/pull/153273


More information about the Mlir-commits mailing list