[Mlir-commits] [mlir] [mlir][LinAlg] Vectorize reverse-like ops using vector.gather ops. (PR #83205)
llvmlistbot at llvm.org
Tue Feb 27 15:15:58 PST 2024
llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Han-Chung Wang (hanhanW)
<details>
<summary>Changes</summary>
A reverse-like op is currently classified as a `VectorMemoryAccessKind::Contiguous` load. The slice it reads is contiguous, but vectorizing it that way would require computing the indices differently and applying a reverse at the vector level, which takes non-trivial effort. This revision instead makes such ops use `vector.gather`; without that change the generated code is functionally incorrect. E.g., the example below loaded elements `2, 3, 4` (a bug), whereas the intended elements are `2, 1, 0`.
Before vectorization:
```mlir
func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> {
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) {
^bb0(%out: f32):
%1 = linalg.index 1 : index
%2 = linalg.index 0 : index
%3 = affine.apply #map1(%1, %2, %arg2)
%4 = linalg.index 2 : index
%5 = arith.subi %c2, %4 : index
%extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32>
linalg.yield %extracted : f32
} -> tensor<1x1x3xf32>
return %0 : tensor<1x1x3xf32>
}
```
Partial IR after vectorization:
```
%5 = vector.constant_mask [1, 1, 3] : vector<1x1x4xi1>
%6 = vector.broadcast %arg0 : index to vector<1x1x4xindex>
%7 = vector.shape_cast %6 : vector<1x1x4xindex> to vector<4xindex>
%8 = vector.extractelement %7[%c0_i32 : i32] : vector<4xindex>
%9 = vector.transfer_read %3[%c0, %8, %c2], %cst, %5 {in_bounds = [true, true, true]} : tensor<1x2x3xf32>, vector<1x1x4xf32>
```
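To make the index mismatch concrete, here is a minimal C++ sketch (not part of the patch) that only computes which columns of the trailing dimension each strategy touches. The helper names `scalarColumns` and `contiguousColumns` are hypothetical; the lane count of 3 is taken from `tensor<1x2x3xf32>` in the example.

```cpp
#include <array>
#include <cstddef>

// Number of lanes actually used, from the trailing dimension of tensor<1x2x3xf32>.
constexpr std::size_t kLanes = 3;

// Columns read by the scalar body: %5 = arith.subi %c2, %4, i.e. 2 - i.
std::array<long, kLanes> scalarColumns() {
  std::array<long, kLanes> cols{};
  for (std::size_t i = 0; i < kLanes; ++i)
    cols[i] = 2 - static_cast<long>(i);
  return cols;
}

// Columns read by a unit-stride (contiguous) load that starts at `start`
// and moves forward by one column per lane.
std::array<long, kLanes> contiguousColumns(long start) {
  std::array<long, kLanes> cols{};
  for (std::size_t i = 0; i < kLanes; ++i)
    cols[i] = start + static_cast<long>(i);
  return cols;
}
```

Under these assumptions, `scalarColumns()` yields `2, 1, 0`, while `contiguousColumns(2)` yields `2, 3, 4`, matching the incorrect load described above: the start index is right, but the direction of travel is not.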
---
Full diff: https://github.com/llvm/llvm-project/pull/83205.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1-2)
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+46)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index ac043e87223dfe..1e703dacfd0c75 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -891,8 +891,7 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
// Conservatively reject Ops that could lead to indices with stride other
// than 1.
- if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
- ancestor))
+ if (!isa<arith::AddIOp, arith::ConstantOp, linalg::IndexOp>(ancestor))
return false;
bool result = false;
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 96953c234a0873..9832b312c32439 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -542,6 +542,52 @@ func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
// CHECK: vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_reverse_like_tensor_extract(%arg0: tensor<1x2x3xf32>, %arg1: tensor<1x1x3xf32>, %arg2: index) -> tensor<1x1x3xf32> {
+ %c1 = arith.constant 1 : index
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %0 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel", "parallel", "parallel"]} outs(%arg1 : tensor<1x1x3xf32>) {
+ ^bb0(%out: f32):
+ %1 = linalg.index 1 : index
+ %2 = linalg.index 0 : index
+ %3 = affine.apply #map1(%1, %2, %arg2)
+ %4 = linalg.index 2 : index
+ %5 = arith.subi %c2, %4 : index
+ %extracted = tensor.extract %arg0[%c0, %3, %5] : tensor<1x2x3xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<1x1x3xf32>
+ return %0 : tensor<1x1x3xf32>
+}
+// CHECK-LABEL: func.func @vectorize_reverse_like_tensor_extract
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]
+// CHECK-DAG: %[[CST:.+]] = arith.constant dense<3> : vector<1x1x3xindex>
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK-DAG: %[[INIT_IDX:.+]] = arith.constant dense<[2, 1, 0]> : vector<3xindex>
+// CHECK: %[[T0:.+]] = vector.broadcast %[[ARG2]] : index to vector<1x1x3xindex>
+// CHECK: %[[T1:.+]] = arith.muli %[[T0]], %[[CST]] : vector<1x1x3xindex>
+// CHECK: %[[T2:.+]] = vector.broadcast %[[INIT_IDX]]
+// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[T1]]
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[T3]]], %[[MASK]], %[[PASSTHRU]]
+// CHECK: vector.transfer_write %[[GATHER]]
+
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
``````````
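The gather indices in the new CHECK lines (`%[[ARG2]]` broadcast and multiplied by `dense<3>`, then added to `dense<[2, 1, 0]>`) can be read as linearized offsets into the source tensor. A rough C++ sketch of that arithmetic follows, assuming a flat row-major buffer of 2x3 elements; `gatherOffsets` and `gather` are hypothetical names for illustration, not MLIR APIs.

```cpp
#include <array>
#include <cstddef>

// Trailing dimension of tensor<1x2x3xf32>.
constexpr std::size_t kCols = 3;

// Linearized offsets row * 3 + (2 - i), mirroring
// T1 = ARG2 * dense<3> and T3 = dense<[2, 1, 0]> + T1 in the CHECK lines.
std::array<std::size_t, kCols> gatherOffsets(std::size_t row) {
  std::array<std::size_t, kCols> offsets{};
  for (std::size_t i = 0; i < kCols; ++i)
    offsets[i] = row * kCols + (kCols - 1 - i);
  return offsets;
}

// Emulates a gather from a flat row-major 2x3 buffer with base offset 0:
// each lane reads the element at its own offset.
std::array<float, kCols> gather(const std::array<float, 6>& data,
                                const std::array<std::size_t, kCols>& offsets) {
  std::array<float, kCols> result{};
  for (std::size_t i = 0; i < kCols; ++i)
    result[i] = data[offsets[i]];
  return result;
}
```

For row 0 the offsets are `2, 1, 0`; for row 1 they are `5, 4, 3`, so each lane reads its own element and the reversed order falls out of the index vector rather than a separate vector-level reverse.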
</details>
https://github.com/llvm/llvm-project/pull/83205