[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)
Andrzej Warzyński
llvmlistbot at llvm.org
Tue Sep 17 06:43:52 PDT 2024
================
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<8x1xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%0 : tensor<8x1xf32>) {
+ ^bb0(%arg5: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = affine.apply #map1(%arg1, %3, %arg1)
+ %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
----------------
banach-space wrote:
Hey Nirvedh, apologies for the delay - I was OOO.
> IIUC, this "jumps" by 768 elements on every iteration.
This should've been 768 * 128, right? Regardless, the overall observation holds (i.e. this is a gather load).
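To spell out the arithmetic (my own illustration, not part of the patch) - with the row-major layout of `tensor<8x128x768xf32>`, a unit step in dim 0 strides over an entire 128x768 slice:
```cpp
// Row-major linearisation of %arg0[d0, d1, d2] for tensor<8x128x768xf32>.
// A unit step in d0 (i.e. linalg.index 0) jumps by 128 * 768 = 98304 elements,
// which is why this access pattern is a gather.
int64_t linearOffset(int64_t d0, int64_t d1, int64_t d2) {
  return d0 * (128 * 768) + d1 * 768 + d2;
}
```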
> If this was classified a gather I believe the generated IR is correct
Yes, looks fine to me.
> looks like we need a more conservative check here?
https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L906-L909
This is slightly more nuanced than that, but it should be easy to fix. First, a bit of context.
In #102321 I relaxed the vectoriser so that reading into a column tensor (e.g. `tensor<8x1xf32>`) is still considered a valid candidate for a "contiguous" load. Note that:
* the following comment is no longer valid (sorry about that, I will update it once we decide on the right approach for this change): https://github.com/llvm/llvm-project/blob/69f3244da76586be393d1e97b01660c6f03d666c/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L893-L894
* the [test](https://github.com/llvm/llvm-project/blob/69f3244da76586be393d1e97b01660c6f03d666c/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir#L112-L142) that I added should fail with your change - do you see why it didn't?
As for the fix, I think what's missing is logic to identify that `%2 = linalg.index 0 : index` (which jumps by 1 on every iteration) is used to calculate a non-trailing index of `%arg0`. Look here: https://github.com/llvm/llvm-project/blob/69f3244da76586be393d1e97b01660c6f03d666c/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L897
That condition should instead read something like:
```cpp
// getNonUnitLoopDim should be easy to implement.
auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
  // Only the linalg.index over the single non-unit loop dim actually advances
  // between iterations, so only that index can make the trailing access
  // contiguous.
  foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
  return true;
}
```
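For reference, here is a rough sketch of the helper I have in mind for `getNonUnitLoopDim` - this is just my assumption of its shape, not existing code, and it presumes a statically shaped op with exactly one non-unit loop dimension (anything else should already have been labelled a gather):
```cpp
// Sketch only - assumes the usual Vectorization.cpp context (LinalgOp,
// llvm::enumerate, std::optional in scope). Returns the index of the single
// non-unit loop dimension, e.g. dim 0 for an op writing into tensor<8x1xf32>.
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
  std::optional<uint64_t> nonUnitDim;
  for (auto [dim, tripCount] : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
    if (tripCount != 1) {
      assert(!nonUnitDim && "expected at most one non-unit loop dim");
      nonUnitDim = dim;
    }
  }
  assert(nonUnitDim && "expected exactly one non-unit loop dim");
  return *nonUnitDim;
}
```
If I'm reading your example right, that would classify it as a gather (its trailing index is driven by `linalg.index 1`, a unit dim), while the column-tensor test from #102321 keeps vectorising as a contiguous load.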
Hopefully this makes sense, but please let me know if not. This is quite convoluted stuff (mea culpa) and I really appreciate you fixing this 🙏🏻
https://github.com/llvm/llvm-project/pull/107922
More information about the Mlir-commits mailing list