[Mlir-commits] [mlir] 84d9694 - [mlir][linalg] Fix vectorisation of tensor.extract with dynamic shapes
Andrzej Warzynski
llvmlistbot at llvm.org
Mon Jul 17 10:28:27 PDT 2023
Author: Andrzej Warzynski
Date: 2023-07-17T17:28:17Z
New Revision: 84d96947ef227e408d27c9db0ee08622d0b20996
URL: https://github.com/llvm/llvm-project/commit/84d96947ef227e408d27c9db0ee08622d0b20996
DIFF: https://github.com/llvm/llvm-project/commit/84d96947ef227e408d27c9db0ee08622d0b20996.diff
LOG: [mlir][linalg] Fix vectorisation of tensor.extract with dynamic shapes
The Linalg vectoriser incorrectly recognises the following
`tensor.extract` as contiguous:
```
func.func @example(%in: tensor<123x321xf32>, %arg1: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
%c0 = arith.constant 1 : index
%2 = linalg.generic {
indexing_maps = [#map1],
iterator_types = ["parallel", "parallel", "parallel"]
} outs(%arg1 : tensor<1x?x8xf32>)
{
^bb0(%arg3: f32):
%idx_0 = linalg.index 0 : index
%idx_1 = linalg.index 1 : index
%idx = arith.addi %idx_0, %idx_1 : index
%7 = tensor.extract %in[%c0, %idx] : tensor<123x321xf32>
linalg.yield %7 : f32
} -> tensor<1x?x8xf32>
return %2 : tensor<1x?x8xf32>
}
```
However, the following index Op corresponds to the dynamic dimension
in the iteration space:
```
%idx_1 = linalg.index 1 : index
```
The vectoriser should assume that:
* this index Op _is not_ loop invariant,
* the resulting memory access is a gather load.
This is what this patch fixes.
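With this fix, masked vectorisation of the example above (with vector sizes [1, 3, 8]) lowers the `tensor.extract` to a masked `vector.gather` instead of a contiguous read. A rough sketch of the resulting IR (value names illustrative; the exact output is captured by the CHECK lines in the test added below):
```
// Sketch only - names like %dim, %indices and %pass_thru are illustrative.
%mask = vector.create_mask %c1, %dim, %c8 : vector<1x3x8xi1>
%res = vector.mask %mask {
  vector.gather %in[%c0, %c0] [%indices], %all_true, %pass_thru
    : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>,
      vector<1x3x8xf32> into vector<1x3x8xf32>
} : vector<1x3x8xi1> -> vector<1x3x8xf32>
```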
Differential Revision: https://reviews.llvm.org/D155373
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a77eccefbf38a..84eea84bd10046 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -925,15 +925,21 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
- // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
+ // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
+ // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
+ // gather load.
+ // TODO: Relax this condition.
+ if (linalgOp.hasDynamicShape())
+ return VectorMemoryAccessKind::Gather;
// 1. Assume that it's a gather load when reading _into_:
// * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
// * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
// TODO: Relax these conditions.
+ // FIXME: This condition assumes non-dynamic sizes.
if ((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
targetShape.back() == 1)
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 0aaa46025c0d4b..4f4e4b92159bc1 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -228,3 +228,46 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
}
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @tensor_extract_dynamic_shape(%arg1: tensor<123x321xf32>, %arg2: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+ %c0 = arith.constant 1 : index
+ %c1 = arith.constant 2 : index
+ %2 = linalg.generic {
+ indexing_maps = [#map1],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } outs(%arg2 : tensor<1x?x8xf32>)
+ {
+ ^bb0(%arg3: f32):
+ %idx_0 = linalg.index 0 : index
+ %idx_1 = linalg.index 1 : index
+ %idx = arith.addi %idx_0, %idx_1 : index
+ %7 = tensor.extract %arg1[%c0, %idx] : tensor<123x321xf32>
+ linalg.yield %7 : f32
+ } -> tensor<1x?x8xf32>
+ return %2 : tensor<1x?x8xf32>
+}
+
+// TODO: Make sure that this is vectorized as "scalar broadcast" when only
+// vectorising the 2nd dimension.
+// CHECK-LABEL: func.func @tensor_extract_dynamic_shape(
+// CHECK-SAME: %[[ARG_1:.*]]: tensor<123x321xf32>,
+// CHECK-SAME: %[[ARG_2:.*]]: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK: %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_2]], %[[C1_2]] : tensor<1x?x8xf32>
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1_1]], %[[DIM]], %[[C8]] : vector<1x3x8xi1>
+// CHECK: %[[MASK_2:.*]] = arith.constant dense<true> : vector<1x3x8xi1>
+// CHECK: %[[FALLTHROUGH:.*]] = arith.constant dense<0.000000e+00> : vector<1x3x8xf32>
+// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MASK]] { vector.gather %[[ARG_1]][%[[C0_1]], %[[C0_1]]] [%{{.*}}], %[[MASK_2]], %[[FALLTHROUGH]] : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>, vector<1x3x8xf32> into vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
+}