[Mlir-commits] [mlir] [mlir][xegpu] Add definitions of MatrixDescType and related ops. (PR #153273)
Chao Chen
llvmlistbot at llvm.org
Thu Aug 14 14:09:22 PDT 2025
================
@@ -751,4 +751,72 @@ gpu.func @fence() {
gpu.return
}
+// CHECK-LABEL: gpu.func @create_matrix_desc({{.*}}) {
+gpu.func @create_matrix_desc() {
+ // Builds a 16x64xf16 matrix descriptor on top of a 2048-byte i8 buffer
+ // allocated in memory space 3.
+ // CHECK: [[ALLOC:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ // CHECK: [[MDESC:%.+]] = xegpu.create_matrix_desc [[ALLOC]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+ %buf = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %desc = xegpu.create_matrix_desc %buf : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK-LABEL: gpu.func @create_matrix_desc_with_stride({{.*}}) {
+gpu.func @create_matrix_desc_with_stride() {
+ // Same buffer as @create_matrix_desc, but the resulting descriptor type
+ // carries an explicit strided<[1, 16]> layout.
+ // CHECK: [[ALLOC:%.+]] = memref.alloca() {alignment = 1024 : i64} : memref<2048xi8, 3>
+ // CHECK: [[MDESC:%.+]] = xegpu.create_matrix_desc [[ALLOC]] : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+ %buf = memref.alloca() {alignment = 1024} : memref<2048xi8, 3>
+ %desc = xegpu.create_matrix_desc %buf : memref<2048xi8, 3> -> !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+ gpu.return
+}
+
+// CHECK: gpu.func @load_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>)
+gpu.func @load_matrix_desc(%arg0: !xegpu.matrix_desc<16x64xf16>) {
+ // Reads a vector<8x16xf16> from the descriptor argument at indices [8, 8].
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+ %tile = xegpu.load_matrix %arg0[8, 8] : !xegpu.matrix_desc<16x64xf16> -> vector<8x16xf16>
+ gpu.return
+}
+
+// Same as @load_matrix_desc, but the descriptor type carries a
+// strided<[1, 16]> layout. The argument is captured here (rather than matched
+// as a literal `%arg0`) so the ARG0 use below binds to this function's own
+// argument instead of a capture left over from the previous test case.
+// CHECK: gpu.func @load_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>)
+gpu.func @load_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>) {
+ // CHECK: xegpu.load_matrix [[ARG0]][8, 8] : !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+ %data = xegpu.load_matrix %arg0[8, 8] : !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>> -> vector<8x16xf16>
+ gpu.return
+}
+
+
+// CHECK: gpu.func @store_matrix_desc([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_desc(%arg0: !xegpu.matrix_desc<16x64xf16>, %arg1: vector<16x16xf16>) {
+ // Writes the vector<16x16xf16> argument into the descriptor at indices [8, 8].
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16>
+ xegpu.store_matrix %arg1, %arg0[8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16>
+ gpu.return
+}
+
+// CHECK: gpu.func @store_matrix_desc_with_stride([[ARG0:%.+]]: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, [[ARG1:%.+]]: vector<16x16xf16>)
+gpu.func @store_matrix_desc_with_stride(%arg0: !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>, %arg1: vector<16x16xf16>) {
+ // CHECK: xegpu.store_matrix [[ARG1]], [[ARG0]][8, 8] : vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
+ xegpu.store_matrix %arg1, %arg0[8, 8]: vector<16x16xf16>, !xegpu.matrix_desc<16x64xf16, strided<[1, 16]>>
----------------
chencha3 wrote:
This format needs two attribute fields, which makes it somewhat hard to extend downstream. I created a `MemLayoutAttr` to encode them, using the format `!xegpu.matrix_desc<8x16xf16, #xegpu.mem_layout<stride = [64, 1]>>`. Does that work for you?
https://github.com/llvm/llvm-project/pull/153273
More information about the Mlir-commits
mailing list