[Mlir-commits] [mlir] [MLIR][XeGPU] Add unroll patterns and blocking pass for XeGPU [2/N] (PR #140163)
Chao Chen
llvmlistbot at llvm.org
Tue May 27 11:26:21 PDT 2025
================
@@ -0,0 +1,248 @@
+// RUN: mlir-opt --xegpu-blocking -split-input-file %s | FileCheck %s
+
+#a = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+#b = #xegpu.layout<inst_data = [16, 16], lane_layout = [1, 16], lane_data = [16, 1]>
+#c = #xegpu.layout<inst_data = [8, 16], lane_layout = [1, 16], lane_data = [8, 1]>
+gpu.module @test_kernel {
+ gpu.func @test_gemm_with_one_to_n_lowering(%A: memref<1024x1024xf16>, %B: memref<1024x1024xf16>, %C: memref<1024x1024xf32>) {
+ %c0 = arith.constant 0 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1024 = arith.constant 1024 : index
+ %block_id_x = gpu.block_id x
+ %block_id_y = gpu.block_id y
+ %m = arith.muli %block_id_x, %c16 : index
+ %n = arith.muli %block_id_y, %c32 : index
+
+ %c_tdesc = xegpu.create_nd_tdesc %C[%m, %n] : memref<1024x1024xf32> -> !xegpu.tensor_desc<16x32xf32, #c>
+ %c_init = xegpu.load_nd %c_tdesc : !xegpu.tensor_desc<16x32xf32, #c> -> vector<16x32xf32>
+
+ %a_tdesc = xegpu.create_nd_tdesc %A[%m, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<16x32xf16, #a>
+ %b_tdesc = xegpu.create_nd_tdesc %B[%c0, %n] : memref<1024x1024xf16> -> !xegpu.tensor_desc<32x32xf16, #b>
+ %out:3 = scf.for %k = %c0 to %c1024 step %c32
+ iter_args(%arg0 = %a_tdesc, %arg1 = %b_tdesc, %arg2 = %c_init)
+ -> (!xegpu.tensor_desc<16x32xf16, #a>, !xegpu.tensor_desc<32x32xf16, #b>, vector<16x32xf32>) {
+ //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<8x16xf16>
+ %a = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x32xf16, #a> -> vector<16x32xf16>
+ //CHECK-COUNT-4: xegpu.load_nd {{.*}} -> vector<16x16xf16>
+ %b = xegpu.load_nd %arg1 : !xegpu.tensor_desc<32x32xf16, #b> -> vector<32x32xf16>
+ //CHECK-COUNT-8: xegpu.dpas {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [8, 1]>} : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32> -> vector<8x16xf32>
----------------
chencha3 wrote:
I intend this test to cover the case where each dimension of a vector/TensorDesc is distributed into more than one piece. It uses a factor of 2, which is the minimum for this case.
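As a rough illustration of the factor-of-2 point, the sketch below recomputes the CHECK-COUNT values from the shapes and `inst_data` sizes in the test. The helper `blocks_per_dim` is a hypothetical name introduced here, not part of the XeGPU pass, and the assumption that one `xegpu.dpas` is emitted per (M, K, N) block triple is inferred from the expected counts rather than taken from the pass implementation.

```python
from math import prod


def blocks_per_dim(shape, inst_data):
    """Hypothetical helper: number of inst_data-sized blocks per dimension."""
    if any(s % b for s, b in zip(shape, inst_data)):
        raise ValueError("shape must be a multiple of inst_data")
    return [s // b for s, b in zip(shape, inst_data)]


# Shapes and inst_data values copied from the test above.
a_blocks = blocks_per_dim((16, 32), (8, 16))    # A tile: M x K
b_blocks = blocks_per_dim((32, 32), (16, 16))   # B tile: K x N
c_blocks = blocks_per_dim((16, 32), (8, 16))    # C tile: M x N

num_a_loads = prod(a_blocks)  # matches CHECK-COUNT-4 for vector<8x16xf16>
num_b_loads = prod(b_blocks)  # matches CHECK-COUNT-4 for vector<16x16xf16>

# Assumption: one dpas per (M block, K block, N block) combination.
num_dpas = a_blocks[0] * a_blocks[1] * b_blocks[1]

print(a_blocks, b_blocks, c_blocks, num_a_loads, num_b_loads, num_dpas)
```

Every dimension is split by exactly 2, so the counts come out as 4 loads of A, 4 loads of B, and 8 dpas ops, in line with the CHECK lines in the test.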
https://github.com/llvm/llvm-project/pull/140163