<table border="1" cellspacing="0" cellpadding="8">
<a href=https://github.com/llvm/llvm-project/issues/87667>87667</a>
[mlir] [Vector] Implement LLVM lowering of vector.extract for 2-D vector types (or flatten them to a 1-D array type)
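The VectorConvertToLLVM pass does not convert %85 = vector.extract %75[0, %83] : i1 from vector<1x16xi1>
to the LLVM dialect (in the reproducer below, the corresponding op is %27 = vector.extract %19[0, %23] : i1 from vector<1x16xi1>, where %23 is a dynamic index). It appears the conversion code handles only 1-D vector types.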
module {
func.func @custom_call_topk_tuple_16_dispatch_0_topk_1x32xf32() {
%cst = arith.constant dense<false> : vector<1xi1>
%cst_0 = arith.constant dense<true> : vector<16xi1>
%c0_i32 = arith.constant 0 : i32
%false = arith.constant false
%c4 = arith.constant 4 : index
%cst_1 = arith.constant 0.000000e+00 : f32
%true = arith.constant true
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<1x32xf32>
memref.assume_alignment %0, 64 : memref<1x32xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<1x4xf32>
memref.assume_alignment %1, 64 : memref<1x4xf32>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c64) : memref<1x4xi32, strided<[4, 1], offset: 16>>
memref.assume_alignment %2, 64 : memref<1x4xi32, strided<[4, 1], offset: 16>>
%workgroup_id_x = hal.interface.workgroup.id[0] : index
%workgroup_count_x = hal.interface.workgroup.count[0] : index
cf.br ^bb1(%workgroup_id_x : index)
^bb1(%3: index): // 2 preds: ^bb0, ^bb24
%4 = arith.cmpi slt, %3, %c1 : index
cf.cond_br %4, ^bb2(%c0, %true, %cst_1, %c0 : index, i1, f32, index), ^bb25
^bb2(%5: index, %6: i1, %7: f32, %8: index): // 2 preds: ^bb1, ^bb23
%9 = arith.cmpi slt, %5, %c32 : index
cf.cond_br %9, ^bb3, ^bb24
^bb3: // pred: ^bb2
%10 = arith.select %6, %c1, %8 : index
cf.cond_br %6, ^bb4, ^bb5(%7 : f32)
^bb4: // pred: ^bb3
%11 = memref.load %0[%c0, %5] : memref<1x32xf32>
memref.store %11, %1[%c0, %c0] : memref<1x4xf32>
memref.store %c0_i32, %2[%c0, %c0] : memref<1x4xi32, strided<[4, 1], offset: 16>>
cf.br ^bb5(%11 : f32)
^bb5(%12: f32): // 2 preds: ^bb3, ^bb4
%13 = vector.load %0[%c0, %5] : memref<1x32xf32>, vector<16xf32>
%14 = vector.broadcast %13 : vector<16xf32> to vector<1x16xf32>
%15 = vector.broadcast %12 : f32 to vector<16xf32>
%16 = arith.cmpf ogt, %13, %15 : vector<16xf32>
%17 = arith.cmpi slt, %10, %c4 : index
%18 = arith.select %17, %cst_0, %16 : vector<16xi1>
%19 = vector.broadcast %18 : vector<16xi1> to vector<1x16xi1>
%20 = vector.reduction <or>, %18, %false : vector<16xi1> into i1
%21 = vector.insertelement %20, %cst[%c0 : index] : vector<1xi1>
%22 = vector.extract %21[0] : i1 from vector<1xi1>
cf.cond_br %22, ^bb6(%c0, %12, %10 : index, f32, index), ^bb23(%12, %10 : f32, index)
^bb6(%23: index, %24: f32, %25: index): // 2 preds: ^bb5, ^bb22
%26 = arith.cmpi slt, %23, %c16 : index
cf.cond_br %26, ^bb7, ^bb23(%24, %25 : f32, index)
^bb7: // pred: ^bb6
%27 = vector.extract %19[0, %23] : i1 from vector<1x16xi1>
%28 = arith.cmpi eq, %27, %true : i1
cf.cond_br %28, ^bb8, ^bb22(%24, %25 : f32, index)
^bb8: // pred: ^bb7
%29 = vector.extract %14[0, %23] : f32 from vector<1x16xf32>
%30 = arith.addi %23, %5 : index
cf.br ^bb9(%c0, %true, %c0 : index, i1, index)
^bb9(%31: index, %32: i1, %33: index): // 2 preds: ^bb8, ^bb12
%34 = arith.cmpi slt, %31, %25 : index
cf.cond_br %34, ^bb10, ^bb13
^bb10: // pred: ^bb9
%35 = arith.cmpi eq, %32, %true : i1
cf.cond_br %35, ^bb11, ^bb12(%32, %33 : i1, index)
^bb11: // pred: ^bb10
%36 = memref.load %1[%c1, %31] : memref<1x4xf32>
%37 = arith.cmpf olt, %36, %29 : f32
%38 = arith.cmpi eq, %37, %true : i1
%39 = arith.cmpi ne, %37, %true : i1
%40 = arith.andi %39, %32 : i1
%41 = arith.select %38, %31, %33 : index
cf.br ^bb12(%40, %41 : i1, index)
^bb12(%42: i1, %43: index): // 2 preds: ^bb10, ^bb11
%44 = arith.addi %31, %c1 : index
cf.br ^bb9(%44, %42, %43 : index, i1, index)
^bb13: // pred: ^bb9
%45 = arith.cmpi eq, %25, %c4 : index
%46 = arith.andi %45, %32 : i1
cf.cond_br %46, ^bb22(%24, %25 : f32, index), ^bb14
^bb14: // pred: ^bb13
cf.cond_br %32, ^bb15, ^bb16
^bb15: // pred: ^bb14
memref.store %29, %1[%c0, %25] : memref<1x4xf32>
%47 = arith.index_cast %30 : index to i32
memref.store %47, %2[%c0, %25] : memref<1x4xi32, strided<[4, 1], offset: 16>>
%48 = arith.addi %25, %c1 : index
cf.br ^bb22(%29, %48 : f32, index)
^bb16: // pred: ^bb14
cf.cond_br %45, ^bb17, ^bb18(%25 : index)
^bb17: // pred: ^bb16
%49 = arith.subi %25, %c1 : index
cf.br ^bb18(%49 : index)
^bb18(%50: index): // 2 preds: ^bb16, ^bb17
%51 = arith.subi %50, %c1 : index
%52 = memref.load %1[%c0, %33] : memref<1x4xf32>
%53 = memref.load %2[%c0, %33] : memref<1x4xi32, strided<[4, 1], offset: 16>>
cf.br ^bb19(%33, %52, %53 : index, f32, i32)
^bb19(%54: index, %55: f32, %56: i32): // 2 preds: ^bb18, ^bb20
%57 = arith.cmpi slt, %54, %51 : index
cf.cond_br %57, ^bb20, ^bb21
^bb20: // pred: ^bb19
%58 = arith.addi %54, %c1 : index
%59 = memref.load %1[%c0, %58] : memref<1x4xf32>
%60 = memref.load %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
memref.store %55, %1[%c0, %58] : memref<1x4xf32>
memref.store %56, %2[%c0, %58] : memref<1x4xi32, strided<[4, 1], offset: 16>>
cf.br ^bb19(%58, %59, %60 : index, f32, i32)
^bb21: // pred: ^bb19
%61 = memref.load %1[%c0, %50] : memref<1x4xf32>
%62 = arith.addi %50, %c1 : index
cf.br ^bb22(%61, %62 : f32, index)
^bb22(%63: f32, %64: index): // 4 preds: ^bb7, ^bb13, ^bb15, ^bb21
%65 = arith.addi %23, %c1 : index
cf.br ^bb6(%65, %63, %64 : index, f32, index)
^bb23(%66: f32, %67: index): // 2 preds: ^bb5, ^bb6
%68 = arith.addi %5, %c16 : index
cf.br ^bb2(%68, %false, %66, %67 : index, i1, f32, index)
^bb24: // pred: ^bb2
%69 = arith.addi %3, %workgroup_count_x : index
cf.br ^bb1(%69 : index)
^bb25: // pred: ^bb1
