<table border="1" cellspacing="0" cellpadding="8">
<tr>
<th>Issue</th>
<td>
<a href="https://github.com/llvm/llvm-project/issues/150185">150185</a>
</td>
</tr>
<tr>
<th>Summary</th>
<td>
[mlir] How to simplify this overly complex size calculation after tiling?
</td>
</tr>
<tr>
<th>Labels</th>
<td>
mlir
</td>
</tr>
<tr>
<th>Assignees</th>
<td>
</td>
</tr>
<tr>
<th>Reporter</th>
<td>
banach-space
</td>
</tr>
</table>
<pre>
Hi folks,
_This is based on https://github.com/iree-org/iree/issues/21393, which was originally posted by @egebeysel (thanks!). I've re-written it using "pure" MLIR (i.e. without requiring IREE)._
**REPRO**
```mlir
func.func @unpack(%arg0: tensor<512x?x8x?xf32>, %arg1: tensor<4096x4096xf32>, %arg2: tensor<512x?x8x?xf32>) -> tensor<4096x4096xf32> {
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.empty() : tensor<4096x4096xf32>
%unpack = linalg.unpack %arg2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %0 : tensor<512x?x8x?xf32> -> tensor<4096x4096xf32>
return %unpack : tensor<4096x4096xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
%unpack = transform.structured.match ops{["linalg.unpack"]} in %module
: (!transform.any_op) -> !transform.any_op
%tiled_unpack, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [8, [8]]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
```
After tiling (`bin/mlir-opt --transform-interpreter unpack.mlir -cse -test-transform-dialect-erase-schedule`):
```mlir
#map = affine_map<(d0)[s0] -> (-d0 + 4096, s0)>
#map1 = affine_map<(d0) -> (d0 floordiv 8)>
#map2 = affine_map<(d0)[s0] -> (d0 floordiv s0)>
#map3 = affine_map<(d0)[s0] -> (d0 mod s0)>
#map4 = affine_map<(d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)>
// ...
%1 = scf.for %arg3 = %c0 to %c4096 step %c8 iter_args(%arg4 = %0) -> (tensor<4096x4096xf32>) {
%2 = scf.for %arg5 = %c0 to %c4096 step %c8_vscale iter_args(%arg6 = %arg4) -> (tensor<4096x4096xf32>) {
%3 = affine.min #map(%arg5)[%c8_vscale]
%4 = affine.apply #map1(%arg3)
%5 = affine.apply #map2(%arg5)[%c8_vscale]
%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
%9 = tensor.empty(%8) : tensor<8x?xf32>
%unpack = linalg.unpack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %9 : tensor<1x?x8x?xf32> -> tensor<8x?xf32>
// ...
```
**ISSUE**
This expression in the output is overly complex:
```mlir
%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
// ...
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
```
Specifically, `%7` should be trivially `1` (see below). As a result, the unpack source has shape `1x?x8x?` instead of `1x1x8x?`, which makes `linalg.unpack` trickier to vectorise.
**WHY SHOULD %7 BE 1?**
For ease of interpretation:
* `%3`: the tile/step size, or the remaining size (on the last iteration)
* `%4`: index of the loop over M (tile/step size = `8`)
* `%5`: index of the loop over N (tile/step size = `8 * vscale`)
* `%7`: the number of inner tiles being read -> it differs from `1` only when the inner tile sizes are **not aligned with the tile sizes**, which is not the case here (both are `8 * vscale`).
Specifically, given that `%7 = ((%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale + 1)`, and:
* `%arg5` is the IV of a loop that starts at `0` and steps by `%c8_vscale`, so `%arg5` is always a multiple of `%c8_vscale`
* `%3 <= 8 * vscale` --> `%3 - 1 < 8 * vscale`
* hence `(%arg5 + %3 - 1) floordiv %c8_vscale == %arg5 floordiv %c8_vscale`
* so `(%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale == 0`

we can safely conclude that `%7 == 1`.
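As a quick numerical sanity check (a Python sketch, not part of the repro; the helper name `map4` is just illustrative), we can evaluate `#map4` over the tiled loop bounds for a few `vscale` values and confirm it always yields `1`:

```python
def map4(d0, d1, s0):
    # #map4 = affine_map<(d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)>
    return (d0 + d1 - 1) // s0 - d0 // s0 + 1

for vscale in (1, 2, 3, 4, 16):
    step = 8 * vscale                      # %c8_vscale
    for arg5 in range(0, 4096, step):      # loop IV: always a multiple of step
        size = min(4096 - arg5, step)      # %3 = affine.min #map(%arg5)[%c8_vscale]
        assert map4(arg5, size, step) == 1

print("#map4 evaluates to 1 for every iteration")
```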
</pre>