[PATCH] D153421: [mlir][Linalg] Implement the tiling interface for softmax
Quentin Colombet via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 21 08:23:16 PDT 2023
qcolombet added inline comments.
================
Comment at: mlir/test/Dialect/Linalg/tile-softmax.mlir:24
+// CHECK: %[[VAL_15:.*]] = tensor.extract_slice %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<16x64x256xf32> to tensor<2x?x256xf32>
+// CHECK: %[[VAL_16:.*]] = linalg.softmax dimension(1) ins(%[[VAL_14]] : tensor<2x?x256xf32>) outs(%[[VAL_15]] : tensor<2x?x256xf32>) -> tensor<2x?x256xf32>
+// CHECK: %[[VAL_17:.*]] = tensor.insert_slice %[[VAL_16]] into %[[VAL_12]]{{\[}}%[[VAL_8]], %[[VAL_11]], 0] [2, %[[VAL_13]], 256] [1, 1, 1] : tensor<2x?x256xf32> into tensor<16x64x256xf32>
----------------
rengolin wrote:
> The semantics of this are to take the max (and the normalizing sum) of each tile, which is not always the same as the softmax over the whole original dimension.
>
> If the tile is a whole `head`, this may be what you want. If not, you'll get different results.
>
> I'm not sure how to restrict this in a meaningful way, or perhaps this is up to the compiler to "do the right thing".
>
> At the very least, this should be documented somewhere, perhaps in the op description?
Good point.
In this example I went with the easiest transformation to apply, but you're right that this may not be correct.
Documenting this somewhere is sensible, but it may make more sense to reject invalid tilings.
At the same time, maybe the compiler/IR author knows something we don't, and we should let this go through...
@ftynse, @nicolasvasilache do you have a recommendation on what we should do here?
Technically `getTiledImplementation` can return a failure, but I don't know whether doing that check there is doable/desirable.
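To make the concern concrete, here is a minimal Python sketch (plain lists, not MLIR, and not the patch's implementation). `tiled_softmax` mimics what happens when the slice fed to `linalg.softmax` covers only part of the softmax dimension, as in the test above where dimension 1 is tiled with a dynamic size: each tile is normalized on its own, so the results no longer match the untiled softmax. `is_softmax_tiling_legal` is a hypothetical helper showing the kind of check that could make `getTiledImplementation` return a failure; it assumes the Linalg convention that a tile size of 0 leaves a dimension untiled, and it ignores dynamic shapes.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a 1-D list."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def tiled_softmax(xs, tile):
    """Hypothetical naive tiling along the softmax dimension.

    Each tile computes its own max and sum, so every tile sums to 1
    instead of the whole row summing to 1.
    """
    out = []
    for start in range(0, len(xs), tile):
        out.extend(softmax(xs[start:start + tile]))
    return out


def is_softmax_tiling_legal(shape, dimension, tile_sizes):
    """Hypothetical legality check (assumption, not an MLIR API).

    Tiling is only safe if the softmax dimension is left untiled
    (tile size 0) or the tile covers its whole static extent.
    """
    t = tile_sizes[dimension]
    return t == 0 or t >= shape[dimension]


if __name__ == "__main__":
    row = [1.0, 2.0, 3.0, 4.0]
    print(sum(softmax(row)))           # ~1.0
    print(sum(tiled_softmax(row, 2)))  # ~2.0: each of the 2 tiles was normalized separately
    # Shape and softmax dimension taken from the test above.
    print(is_softmax_tiling_legal((16, 64, 256), 1, (2, 8, 0)))  # False
    print(is_softmax_tiling_legal((16, 64, 256), 1, (2, 0, 0)))  # True
```

Tiling the other dimensions (here 0 and 2) stays correct because each softmax row is independent; only splitting dimension 1 changes the result, which is why the check looks at that single entry of the tile sizes.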
================
Comment at: mlir/test/Dialect/Linalg/tile-softmax.mlir:39
+
+transform.sequence failures(propagate) {
+ ^bb0(%arg1: !transform.any_op):
----------------
rengolin wrote:
> I'm assuming this also works with the tile-and-fuse pass. Could there be a test for that, too?
Let me give it a try!
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D153421