[Mlir-commits] [mlir] [mlir][xegpu] Add support for structured control flow ops in workgroup to subgroup distribution (PR #142618)

Nishant Patel llvmlistbot at llvm.org
Tue Jun 10 15:52:43 PDT 2025


================
@@ -169,4 +169,125 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
       : vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
     gpu.return
   }
+
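+  // Distribution through scf.for: the workgroup-level 128x128 iter_args (the A
+  // and B tensor_descs and the C accumulator vector) are expected to be
+  // rewritten to per-subgroup 16x128, 128x16, and 16x16 shapes, matching the
+  // sg_layout = [8, 8] / sg_data attributes on each operand.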
+  gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
+    // CHECK: [[c0:%.+]] = arith.constant 0 : index
+    // CHECK: [[c128:%.+]] = arith.constant 128 : index
+    // CHECK: [[c1024:%.+]] = arith.constant 1024 : index
+    %c0 = arith.constant 0 : index
+    %c128 = arith.constant 128 : index
+    %c1024 = arith.constant 1024 : index
+    %block_id_x = gpu.block_id  x
+    %block_id_y = gpu.block_id  y
+    %0 = arith.muli %block_id_x, %c128 : index
+    %1 = arith.muli %block_id_y, %c128 : index
+    %2 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    %3 = xegpu.load_nd %2 : !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>> -> vector<128x128xf32>
+    %4 = xegpu.create_nd_tdesc %arg0[%0, %c0] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+    %5 = xegpu.create_nd_tdesc %arg1[%c0, %1] : memref<1024x1024xf16> -> !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+
+    // CHECK: [[scf:%.+]]:3 = scf.for [[arg3:%.+]] = [[c0]] to [[c1024]] step [[c128]] iter_args([[arg4:%.+]] = {{.*}}, [[arg5:%.+]] = {{.*}}, [[arg6:%.+]] = {{.*}}) -> (!xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>)
+    // CHECK: [[a:%.+]] = xegpu.load_nd [[arg4]] : !xegpu.tensor_desc<16x128xf16> -> vector<16x128xf16>
+    // CHECK: [[b:%.+]] = xegpu.load_nd [[arg5]] : !xegpu.tensor_desc<128x16xf16> -> vector<128x16xf16>
+    // CHECK: [[c:%.+]] = xegpu.dpas [[a]], [[b]], [[arg6]] : vector<16x128xf16>, vector<128x16xf16>, vector<16x16xf32> -> vector<16x16xf32>
+    // CHECK: [[at:%.+]] = xegpu.update_nd_offset [[arg4]], [[[c0]], [[c128]]] : !xegpu.tensor_desc<16x128xf16>
+    // CHECK: [[bt:%.+]] = xegpu.update_nd_offset [[arg5]], [[[c128]], [[c0]]] : !xegpu.tensor_desc<128x16xf16>
+    // CHECK: scf.yield [[at]], [[bt]], [[c]] : !xegpu.tensor_desc<16x128xf16>, !xegpu.tensor_desc<128x16xf16>, vector<16x16xf32>
+    %6:3 = scf.for %arg3 = %c0 to %c1024 step %c128 iter_args(%arg4 = %4, %arg5 = %5, %arg6 = %3) -> (!xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>) {
+      %8 = xegpu.load_nd %arg4  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>> -> vector<128x128xf16>
+      %9 = xegpu.load_nd %arg5  : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>> -> vector<128x128xf16>
+      %10 = xegpu.dpas %8, %9, %arg6 {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>} : vector<128x128xf16>, vector<128x128xf16>, vector<128x128xf32> -> vector<128x128xf32>
+      %11 = xegpu.update_nd_offset %arg4, [%c0, %c128] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>
+      %12 = xegpu.update_nd_offset %arg5, [%c128, %c0] : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>
+      scf.yield %11, %12, %10 : !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 128]>>, !xegpu.tensor_desc<128x128xf16, #xegpu.layout<sg_layout = [8, 8], sg_data = [128, 16]>>, vector<128x128xf32>
+    }
+    %7 = xegpu.create_nd_tdesc %arg2[%0, %1] : memref<1024x1024xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    xegpu.store_nd %6#2, %7  : vector<128x128xf32>, !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16]>>
+    gpu.return
+  }
+
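+  // Distribution through scf.while/scf.condition: the vector<256xf32>
+  // loop-carried value is expected to become a per-subgroup vector<16xf32>
+  // (sg_layout = [16], sg_data = [16]), while the scalar i32 counter is
+  // carried through unchanged.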
+  gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c1_i32 = arith.constant 1 : i32
+    %c10_i32 = arith.constant 10 : i32
+    %c0_i32 = arith.constant 0 : i32
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %1 = xegpu.load_nd %0  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
+    %2 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+
+    // CHECK: scf.while {{.*}} : (vector<16xf32>, i32) -> (vector<16xf32>, i32)
+    %3:2 = scf.while (%arg2 = %1, %arg3 = %c0_i32) : (vector<256xf32>, i32) -> (vector<256xf32>, i32) {
+      %4 = arith.cmpi slt, %arg3, %c10_i32 : i32
+      // CHECK: scf.condition{{.*}} : vector<16xf32>, i32
+      scf.condition(%4) %arg2, %arg3 : vector<256xf32>, i32
+    } do {
+    // CHECK: ([[arg2:%.+]]: vector<16xf32>, [[arg3:%.+]]: i32)
+    ^bb0(%arg2: vector<256xf32>, %arg3: i32):
+      xegpu.store_nd %arg2, %2  : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      %4 = arith.addi %arg3, %c1_i32 : i32
+      %5 = xegpu.update_nd_offset %0, [256] : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      %6 = xegpu.load_nd %5  : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
+      scf.yield %6, %4 : vector<256xf32>, i32
+    }
+    gpu.return
+  }
+
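+  // Distribution through scf.if yielding a vector: both branches and the if
+  // result are expected to drop from vector<256xf32> to the per-subgroup
+  // vector<16xf32>, as dictated by the layout_result_0 attribute on the op.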
+  gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %id = gpu.subgroup_id : index
+
+    %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+
+    %4 = arith.cmpi eq, %id, %c10 : index
+    // CHECK-LABEL: scf.if
+    //  CHECK-SAME: (vector<16xf32>)
+    %5 = scf.if %4 -> (vector<256xf32>) {
+      // CHECK-LABEL: xegpu.load_nd
+      //  CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+      %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      //  CHECK-SAME: vector<16xf32>
+      scf.yield %2 : vector<256xf32>
+    } else {
+      // CHECK-LABEL: xegpu.load_nd
+      //  CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
+      %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
+      // CHECK-LABEL: scf.yield
+      //  CHECK-SAME: vector<16xf32>
+      scf.yield %3 : vector<256xf32>
+    } {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
+    xegpu.store_nd %5, %0 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    gpu.return
+  }
+
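+  // Distribution through scf.if yielding a tensor_desc: the yielded descriptor
+  // is expected to shrink to the 16-element subgroup tile, with the sg_layout /
+  // sg_data layout attribute dropped from the distributed type.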
+  gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
+    %c10 = arith.constant 10 : index
+    %id = gpu.subgroup_id : index
+
+    %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
+
+    %0 = arith.cmpi eq, %id, %c10 : index
+    // CHECK-LABEL: scf.if
+    //  CHECK-SAME: (!xegpu.tensor_desc<16xf32>)
+    %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>) {
+      // CHECK-LABEL: xegpu.create_nd_tdesc
+      //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+      %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      // CHECK-LABEL: scf.yield
+      //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
+      scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    } else {
+      // CHECK-LABEL: xegpu.create_nd_tdesc
+      //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
+      %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+      // CHECK-LABEL: scf.yield
+      //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
+      scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    }
+    xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
+    gpu.return
+  }
+
+
----------------
nbpatel wrote:

Ok, we can add it later.

https://github.com/llvm/llvm-project/pull/142618

