[Mlir-commits] [mlir] [mlir][xegpu] Add definitions of MatrixDescType and related ops. (PR #153273)

Chao Chen llvmlistbot at llvm.org
Thu Aug 14 14:09:22 PDT 2025


================
@@ -751,4 +751,72 @@ gpu.func @fence() {
   gpu.return
 }
 
+// CHECK-LABEL: gpu.func @create_matrix_desc({{.*}}) {
+gpu.func @create_matrix_desc() {
+  // Wrap a 2048-byte i8 buffer (memory space 3) in a 16x64xf16 matrix
+  // descriptor with no explicit layout; 16 * 64 f16 elements = 2048 bytes.
+  // CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  // CHECK: [[mdesc:%.+]] = xegpu.create_matrix_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+  %buf = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %desc = xegpu.create_matrix_desc %buf : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_matrix_desc_with_stride({{.*}}) {
+gpu.func @create_matrix_desc_with_stride() {
+  // Wrap a 2048-byte i8 buffer (memory space 3) in a 16x64xf16 matrix
+  // descriptor that carries an explicit strided<[1, 16]> layout.
+  // CHECK: [[alloc:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+  // CHECK: [[mdesc:%.+]] = xegpu.create_matrix_desc [[alloc]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+  %buf = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+  %desc = xegpu.create_matrix_desc %buf : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+  gpu.return
+}
+
+// CHECK: gpu.func @load_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>)
+gpu.func @load_matrix_desc(%arg0: !xegpu.matrix_desc<16x64xf16>) {
+  // Round-trip a load of an 8x16xf16 vector at offsets [8, 8] from a
+  // descriptor with no explicit layout.
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+  %tile = xegpu.load_matrix %arg0[8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @load_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>)
+gpu.func @load_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>) {
+  // Round-trip a load of an 8x16xf16 vector at offsets [8, 8] from a
+  // descriptor with a strided<[1, 16]> layout. ARG0 is bound on the signature
+  // line above so the load is matched against this function's own argument,
+  // not against a capture left over from an earlier test function.
+  // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+  %data = xegpu.load_matrix %arg0[8, 8]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+  gpu.return
+}
+
+
+// CHECK: gpu.func @store_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_desc(%desc: !xegpu.matrix_desc<16x64xf16>, %vals: vector<16x16xf16>) {
+  // Round-trip a store of a 16x16xf16 vector at offsets [8, 8] into a
+  // descriptor with no explicit layout (value operand first, then descriptor).
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16>
+  xegpu.store_matrix %vals, %desc[8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16>
+  gpu.return
+}
+
+// CHECK: gpu.func @store_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, %arg1: vector<16x16xf16>) {
+  // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+  xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
----------------
chencha3 wrote:

This format needs two attribute fields, which makes it somewhat hard to extend downstream. I created a `MemLayoutAttr` to encode both of them, using the format `!xegpu.matrix_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>`. Does that work for you?

https://github.com/llvm/llvm-project/pull/153273


More information about the Mlir-commits mailing list