[Mlir-commits] [mlir] [mlir][linalg] Enable Vectorization of 0-D tensor.extract (PR #119079)
Andrzej Warzyński
llvmlistbot at llvm.org
Sat Dec 7 08:33:30 PST 2024
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/119079
This patch removes an assert in `vectorizeTensorExtract` that was blocking
the vectorization of 0-D tensor.extract operations, e.g.:
```mlir
%1 = tensor.extract %src[] : tensor<f32>
```
As demonstrated by the included tests, this case is already effectively
supported.
**Context**
The removed assert was introduced in #109580 as a guard, pending proper support
and testing for 0-D tensors. This PR addresses that previously undocumented
TODO. Apologies for the oversight!
**Updates and Tests**
* Revised the existing test `@negative_no_loop` to ensure the
`vectorize_nd_extract` attribute is included, allowing the vectorizer
to process it. The test was renamed and variables updated for clarity.
* Added a new test `@extract_scalar_from_0d_into_1d` to cover "mixed"
0-D/1-D tensor extraction, e.g.:
```mlir
%res = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel"]
} outs(%init : tensor<1xf32>) {
^bb0(%in: f32):
%1 = tensor.extract %src[] : tensor<f32>
linalg.yield %1 : f32
} -> tensor<1xf32>
return %res : tensor<1xf32>
```
**Additional updates**
I also took the liberty of improving test coverage for 0-D tensors in the
vectorizer tests:
* Added a specific test for "0D linalg.generic" in
"vectorization-with-patterns.mlir".
* Renamed several tests in "vectorization-with-patterns.mlir" to clarify
that the 0-D case is now covered.
>From 1e6c8b3085b397fedfdba36e568230ab40bde68f Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Sat, 7 Dec 2024 16:11:55 +0000
Subject: [PATCH] [mlir][linalg] Enable Vectorization of 0-D tensor.extract
This patch removes an assert in `vectorizeTensorExtract` that was blocking
the vectorization of 0-D tensor.extract operations, e.g.:
```mlir
%1 = tensor.extract %src[] : tensor<f32>
```
As demonstrated by the included tests, this case is already effectively
supported.
**Context**
The removed assert was introduced in #109580 as a guard, pending proper support
and testing for 0-D tensors. This PR addresses that previously undocumented
TODO. Apologies for the oversight!
**Updates and Tests**
* Revised the existing test `@negative_no_loop` to ensure the
`vectorize_nd_extract` attribute is included, allowing the vectorizer
to process it. The test was renamed and variables updated for clarity.
* Added a new test `@extract_scalar_from_0d_into_1d` to cover "mixed"
0-D/1-D tensor extraction, e.g.:
```mlir
%res = linalg.generic {
indexing_maps = [#map],
iterator_types = ["parallel"]
} outs(%init : tensor<1xf32>) {
^bb0(%in: f32):
%1 = tensor.extract %src[] : tensor<f32>
linalg.yield %1 : f32
} -> tensor<1xf32>
return %res : tensor<1xf32>
```
**Additional updates**
I also took the liberty of improving test coverage for 0-D tensors in the
vectorizer tests:
* Added a specific test for "0D linalg.generic" in
"vectorization-with-patterns.mlir".
* Renamed several tests in "vectorization-with-patterns.mlir" to clarify
that the 0-D case is now covered.
---
.../Linalg/Transforms/Vectorization.cpp | 5 --
.../Linalg/vectorization-with-patterns.mlir | 48 ++++++++++++++-
.../Linalg/vectorize-tensor-extract.mlir | 58 +++++++++++++++----
3 files changed, 93 insertions(+), 18 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index e5c96b52acee23..863f2280e46ce6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1115,11 +1115,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// b. contiguous loads.
// Both cases use vector.transfer_read.
- assert(llvm::count_if(resultType.getShape(),
- [](uint64_t dim) { return dim != 1; }) &&
- "Contiguous loads and scalar loads + broadcast only support 1-D "
- "vectors ATM!");
-
// Collect indices for `vector.transfer_read`. At this point, the indices will
// either be scalars or would have been broadcast to vectors matching the
// result type. For indices that are vectors, there are two options:
diff --git a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
index 0c996bed996d3c..b688a677500c22 100644
--- a/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-with-patterns.mlir
@@ -122,6 +122,48 @@ module attributes {transform.with_named_sequence} {
// -----
+#map = affine_map<() -> ()>
+
+// CHECK-LABEL: func.func @generic_0d(
+// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>, %[[ARG_1:.*]]: tensor<f32>, %[[ARG_2:.*]]: tensor<f32>)
+func.func @generic_0d(%arg0: tensor<f32>, %arg1: tensor<f32>,
+ %arg2: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ_0:.*]] = vector.transfer_read %[[ARG_0]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_0_AS_SCALAR:.*]] = vector.extract %[[READ_0]][] : f32 from vector<f32>
+// CHECK: %[[READ_1:.*]] = vector.transfer_read %[[ARG_1]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_1_AS_SCALAR:.*]] = vector.extract %[[READ_1]][] : f32 from vector<f32>
+// CHECK: %[[READ_2:.*]] = vector.transfer_read %[[ARG_2]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[ARG_2_AS_SCALAR:.*]] = vector.extract %[[READ_2]][] : f32 from vector<f32>
+// CHECK: %[[MULF:.*]] = arith.mulf %[[ARG_0_AS_SCALAR]], %[[ARG_1_AS_SCALAR]] : f32
+// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_2_AS_SCALAR]], %[[MULF]] : f32
+// CHECK: %[[ADDF_BCAST:.*]] = vector.broadcast %[[ADDF]] : f32 to vector<f32>
+// CHECK: vector.transfer_write %[[ADDF_BCAST]], %[[ARG_2]][] : vector<f32>, tensor<f32>
+ %res = linalg.generic {
+ indexing_maps = [#map, #map, #map],
+ iterator_types = []
+ } ins(%arg0, %arg1 : tensor<f32>, tensor<f32>)
+ outs(%arg2 : tensor<f32>) {
+ ^bb(%a: f32, %b: f32, %c: f32) :
+ %d = arith.mulf %a, %b: f32
+ %e = arith.addf %c, %d: f32
+ linalg.yield %e : f32
+ } -> tensor<f32>
+
+ return %res : tensor<f32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
#matmul_transpose_out_trait = {
indexing_maps = [
affine_map<(m, n, k) -> (m, k)>,
@@ -372,7 +414,7 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-LABEL: func @test_vectorize_fill
-func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
+func.func @test_vectorize_fill_0d(%A : memref<f32>, %arg0 : f32) {
// CHECK-SAME: (%[[M:.*]]: memref<f32>, %[[val:.*]]: f32)
// CHECK: %[[VEC:.*]] = vector.broadcast %[[val]] : f32 to vector<f32>
// CHECK: vector.transfer_write %[[VEC]], %[[M]][] : vector<f32>, memref<f32>
@@ -410,8 +452,8 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-LABEL: func @test_vectorize_copy_scalar
-func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
+// CHECK-LABEL: func @test_vectorize_copy_0d
+func.func @test_vectorize_copy_0d(%A : memref<f32>, %B : memref<f32>) {
// CHECK-SAME: (%[[A:.*]]: memref<f32>, %[[B:.*]]: memref<f32>)
// CHECK: %[[V:.*]] = vector.transfer_read %[[A]][]{{.*}} : memref<f32>, vector<f32>
// CHECK: %[[val:.*]] = vector.extract %[[V]][] : f32 from vector<f32>
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 1a93d1cd9b7880..775ceed31be04a 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -39,29 +39,67 @@ module attributes {transform.with_named_sequence} {
// -----
#map = affine_map<() -> ()>
-func.func @negative_no_loops(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
- %1 = linalg.generic {
+func.func @extract_scalar_from_0d_into_0d(%src: tensor<f32>, %init: tensor<f32>) -> tensor<f32> {
+ %res = linalg.generic {
indexing_maps = [#map],
iterator_types = []
- } outs(%arg1 : tensor<f32>) {
- ^bb0(%arg4: f32):
- %2 = tensor.extract %arg0[] : tensor<f32>
- linalg.yield %2 : f32
+ } outs(%init : tensor<f32>) {
+ ^bb0(%in: f32):
+ %1 = tensor.extract %src[] : tensor<f32>
+ linalg.yield %1 : f32
} -> tensor<f32>
- return %1 : tensor<f32>
+
+ return %res : tensor<f32>
}
-// CHECK-LABEL: func.func @negative_no_loops
-// CHECK: tensor.extract
+
+// CHECK-LABEL: func.func @extract_scalar_from_0d_into_0d(
+// CHECK-SAME: %[[SRC:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INIT:.*]]: tensor<f32>) -> tensor<f32> {
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: vector.transfer_write %[[READ]], %[[INIT]][] : vector<f32>, tensor<f32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
- %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
+// -----
+
+#map = affine_map<(n) -> (n)>
+func.func @extract_scalar_from_0d_into_1d(%src: tensor<f32>, %init: tensor<1xf32>) -> tensor<1xf32> {
+ %res = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel"]
+ } outs(%init : tensor<1xf32>) {
+ ^bb0(%in: f32):
+ %1 = tensor.extract %src[] : tensor<f32>
+ linalg.yield %1 : f32
+ } -> tensor<1xf32>
+
+ return %res : tensor<1xf32>
+}
+// CHECK-LABEL: func.func @extract_scalar_from_0d_into_1d(
+// CHECK-SAME: %[[SRC:.*]]: tensor<f32>,
+// CHECK-SAME: %[[INIT:.*]]: tensor<1xf32>) -> tensor<1xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PAD:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[READ:.*]] = vector.transfer_read %[[SRC]][], %[[PAD]] : tensor<f32>, vector<f32>
+// CHECK: %[[READ_BCAST:.*]] = vector.broadcast %[[READ]] : vector<f32> to vector<1xf32>
+// CHECK: vector.transfer_write %[[READ_BCAST]], %[[INIT]][%[[C0]]] {in_bounds = [true]} : vector<1xf32>, tensor<1xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
// -----
More information about the Mlir-commits
mailing list