[Mlir-commits] [mlir] 84d9694 - [mlir][linalg] Fix vectorisation of tensor.extract with dynamic shapes

Andrzej Warzynski llvmlistbot at llvm.org
Mon Jul 17 10:28:27 PDT 2023


Author: Andrzej Warzynski
Date: 2023-07-17T17:28:17Z
New Revision: 84d96947ef227e408d27c9db0ee08622d0b20996

URL: https://github.com/llvm/llvm-project/commit/84d96947ef227e408d27c9db0ee08622d0b20996
DIFF: https://github.com/llvm/llvm-project/commit/84d96947ef227e408d27c9db0ee08622d0b20996.diff

LOG: [mlir][linalg] Fix vectorisation of tensor.extract with dynamic shapes

The Linalg vectoriser incorrectly classifies the `tensor.extract` in the
following example as a contiguous load:
```
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @example(%in: tensor<123x321xf32>, %arg1: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
  %c0 = arith.constant 1 : index
  %2 = linalg.generic {
    indexing_maps = [#map1],
    iterator_types = ["parallel", "parallel", "parallel"]
  } outs(%arg1 : tensor<1x?x8xf32>)
  {
  ^bb0(%arg3: f32):
    %idx_0 = linalg.index 0 : index
    %idx_1 = linalg.index 1 : index
    %idx = arith.addi %idx_0, %idx_1 : index
    %7 = tensor.extract %in[%c0, %idx] : tensor<123x321xf32>
    linalg.yield %7 : f32
  } -> tensor<1x?x8xf32>
  return %2 : tensor<1x?x8xf32>
}
```

However, the index Op below corresponds to the dynamic dimension of the
iteration space (the `?` in `tensor<1x?x8xf32>`):
```
    %idx_1 = linalg.index 1 : index
```
Since the extent of that dimension is unknown at compile time, the
vectoriser should assume that:
  * this index Op _is not_ loop invariant, and
  * the resulting memory access is a gather load.
This is what this patch fixes.
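
In essence, the fix is an early bail-out in
`getTensorExtractMemoryAccessPattern`, before the contiguity analysis runs.
Below is a minimal sketch of the updated control flow; the new check matches
the diff further down, while the second parameter and the elided analysis are
assumptions based on the surrounding code:
```
// Sketch only: mirrors the control flow this patch adds to
// getTensorExtractMemoryAccessPattern() in
// mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp. The real
// function also analyses the index computation feeding the extract.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                    LinalgOp &linalgOp) {
  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());

  // 0.1 A 0-D input tensor is always a scalar broadcast.
  if (inputShape.getShape().empty())
    return VectorMemoryAccessKind::ScalarBroadcast;

  // 0.2 (new) With a dynamic dimension anywhere in the iteration space,
  // the loop-invariance reasoning below is unsound, so conservatively
  // fall back to a gather load.
  if (linalgOp.hasDynamicShape())
    return VectorMemoryAccessKind::Gather;

  // 1. onwards: static-shape analysis (contiguous vs. gather),
  // unchanged by this patch; gather remains the conservative default.
  return VectorMemoryAccessKind::Gather;
}
```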

Differential Revision: https://reviews.llvm.org/D155373

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0a77eccefbf38a..84eea84bd10046 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -925,15 +925,21 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   auto targetShape = linalgOp.getStaticLoopRanges();
   auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
-  // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
+  // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
   if (inputShape.getShape().empty())
     return VectorMemoryAccessKind::ScalarBroadcast;
 
+  // 0.2 In the case of dynamic shapes just bail out and assume that it's a
+  // gather load.
+  // TODO: Relax this condition.
+  if (linalgOp.hasDynamicShape())
+    return VectorMemoryAccessKind::Gather;
 
   // 1. Assume that it's a gather load when reading _into_:
   //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
   //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
   // TODO: Relax these conditions.
+  // FIXME: This condition assumes non-dynamic sizes.
   if ((llvm::count_if(targetShape,
                       [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
       targetShape.back() == 1)

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 0aaa46025c0d4b..4f4e4b92159bc1 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -228,3 +228,46 @@ transform.sequence failures(propagate) {
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract } : !transform.any_op
  }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @tensor_extract_dynamic_shape(%arg1: tensor<123x321xf32>, %arg2: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x?x8xf32>)
+  {
+  ^bb0(%arg3: f32):
+    %idx_0 = linalg.index 0 : index
+    %idx_1 = linalg.index 1 : index
+    %idx = arith.addi %idx_0, %idx_1 : index
+    %7 = tensor.extract %arg1[%c0, %idx] : tensor<123x321xf32>
+    linalg.yield %7 : f32
+  } -> tensor<1x?x8xf32>
+  return %2 : tensor<1x?x8xf32>
+} 
+
+// TODO: Make sure that this is vectorized as "scalar broadcast" when only
+// vectorising the 2nd dimension.
+// CHECK-LABEL:   func.func @tensor_extract_dynamic_shape(
+// CHECK-SAME:      %[[ARG_1:.*]]: tensor<123x321xf32>,
+// CHECK-SAME:      %[[ARG_2:.*]]: tensor<1x?x8xf32>) -> tensor<1x?x8xf32> {
+// CHECK:           %[[C2:.*]] = arith.constant 2 : index
+// CHECK:           %[[C1_1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C1_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[DIM:.*]] = tensor.dim %[[ARG_2]], %[[C1_2]] : tensor<1x?x8xf32>
+// CHECK:           %[[C8:.*]] = arith.constant 8 : index
+// CHECK:           %[[MASK:.*]] = vector.create_mask %[[C1_1]], %[[DIM]], %[[C8]] : vector<1x3x8xi1>
+// CHECK:           %[[MASK_2:.*]] = arith.constant dense<true> : vector<1x3x8xi1>
+// CHECK:           %[[FALLTHROUGH:.*]] = arith.constant dense<0.000000e+00> : vector<1x3x8xf32>
+// CHECK:           %[[C0_1:.*]] = arith.constant 0 : index
+// CHECK:           vector.mask %[[MASK]] { vector.gather %[[ARG_1]][%[[C0_1]], %[[C0_1]]] [%{{.*}}], %[[MASK_2]], %[[FALLTHROUGH]] : tensor<123x321xf32>, vector<1x3x8xindex>, vector<1x3x8xi1>, vector<1x3x8xf32> into vector<1x3x8xf32> } : vector<1x3x8xi1> -> vector<1x3x8xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+   transform.structured.masked_vectorize %0 vector_sizes [1, 3, 8] { vectorize_nd_extract } : !transform.any_op
+}
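
The new test reuses the RUN line already at the top of
vectorize-tensor-extract-masked.mlir (not shown in this diff); at this
revision it is presumably of the form:
```
// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
```
The `// -----` marker above starts a fresh `-split-input-file` case, which is
why the new test carries its own `transform.sequence` driving
`masked_vectorize` with vector sizes `[1, 3, 8]` and `vectorize_nd_extract`
enabled.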


        

