[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)
Nirvedh Meshram
llvmlistbot at llvm.org
Tue Sep 10 15:28:51 PDT 2024
================
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<8x1xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%0 : tensor<8x1xf32>) {
+ ^bb0(%arg5: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = affine.apply #map1(%arg1, %3, %arg1)
+ %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
----------------
nirvedhmeshram wrote:
If this were classified as a gather, I believe the generated IR would be correct:
```
// Gather-based vectorization of the 8x1 linalg.generic whose body does
// tensor.extract %arg0[d0, 0, %arg1 + d1 + %arg1].
// Because the result is 8x1, d1 only takes the value 0. For output row i, the
// index vector %8 therefore holds the row-major linear offset
// i * 128 * 768 + (%arg1 + %arg1), measured from %arg0[0, 0, 0].
func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1: index) -> tensor<8x1xf32> {
%cst = arith.constant dense<768> : vector<1x8xindex>
%cst_0 = arith.constant dense<128> : vector<1x8xindex>
%c0 = arith.constant 0 : index
// Pass-through value and all-true mask for the gather below.
%cst_1 = arith.constant dense<0.000000e+00> : vector<8x1xf32>
%cst_2 = arith.constant dense<true> : vector<8x1xi1>
// Lane iota standing in for linalg.index 0 (d0 = 0..7).
%cst_3 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
%0 = tensor.empty() : tensor<8x1xf32>
%1 = vector.broadcast %cst_3 : vector<8xindex> to vector<1x8xindex>
// Scalar part of the innermost index: %arg1 + d1 + %arg1 with d1 = 0.
%2 = arith.addi %arg1, %arg1 : index
%3 = vector.broadcast %2 : index to vector<1xindex>
// Scale d0 by 128 * 768, the number of elements per d0 slice of %arg0.
%4 = arith.muli %1, %cst_0 : vector<1x8xindex>
%5 = arith.muli %4, %cst : vector<1x8xindex>
// Reshape the per-lane offsets to the 8x1 result shape.
%6 = vector.transpose %5, [1, 0] : vector<1x8xindex> to vector<8x1xindex>
%7 = vector.broadcast %3 : vector<1xindex> to vector<8x1xindex>
// Final offsets: d0 * 128 * 768 + (%arg1 + %arg1).
%8 = arith.addi %7, %6 : vector<8x1xindex>
%9 = vector.gather %arg0[%c0, %c0, %c0] [%8], %cst_2, %cst_1 : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
// Store the gathered values into the empty 8x1 output tensor.
%10 = vector.transfer_write %9, %0[%c0, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
return %10 : tensor<8x1xf32>
}
```
https://github.com/llvm/llvm-project/pull/107922
More information about the Mlir-commits
mailing list