[Mlir-commits] [mlir] 2712b28 - [mlir][linalg] Vectorize 0-d tensor extract

Andrzej Warzynski llvmlistbot at llvm.org
Thu Jul 6 00:34:57 PDT 2023


Author: Andrzej Warzynski
Date: 2023-07-06T08:34:51+01:00
New Revision: 2712b2805b47f10b3864ab19a4016ea306126ad7

URL: https://github.com/llvm/llvm-project/commit/2712b2805b47f10b3864ab19a4016ea306126ad7
DIFF: https://github.com/llvm/llvm-project/commit/2712b2805b47f10b3864ab19a4016ea306126ad7.diff

LOG: [mlir][linalg] Vectorize 0-d tensor extract

This patch adds the missing logic to vectorise `tensor.extract` for 0-d
tensors.
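
For illustration, a 0-d `tensor.extract` takes no indices; the
vectoriser turns it into a scalar read followed by a broadcast. A
minimal sketch, based on the test added below (the vector shape here
comes from the surrounding `linalg.generic` and is illustrative only):

    %0 = tensor.extract %src[] : tensor<f32>
    %1 = vector.broadcast %0 : f32 to vector<1x1x3xf32>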

Fixes #63688

Differential Revision: https://reviews.llvm.org/D154518

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0321220484aeb2..0a77eccefbf38a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -745,8 +745,12 @@ tensorExtractVectorizationPrecondition(Operation *op, bool vectorizeNDExtract) {
   if (extractOp.getIndices().size() != 1 && !vectorizeNDExtract)
     return failure();
 
-  if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
-    return failure();
+  // Check the index type, but only for non-0-d tensors (for which access
+  // indices are actually needed).
+  if (!extractOp.getIndices().empty()) {
+    if (!VectorType::isValidElementType(extractOp.getIndices()[0].getType()))
+      return failure();
+  }
 
   if (llvm::any_of(extractOp->getResultTypes(), [](Type type) {
         return !VectorType::isValidElementType(type);
@@ -919,6 +923,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
                                     LinalgOp &linalgOp) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
+  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
+
+  // 0. Is this a 0-D tensor? If so, this is a scalar broadcast.
+  if (inputShape.getShape().empty())
+    return VectorMemoryAccessKind::ScalarBroadcast;
+
 
   // 1. Assume that it's a gather load when reading _into_:
  //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
@@ -929,7 +939,6 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
       targetShape.back() == 1)
     return VectorMemoryAccessKind::Gather;
 
-  auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
 
   // 2. Assume that it's a gather load when reading _from_ a tensor for which
   // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
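
For contrast, a hypothetical extract that the heuristic above would
still classify as a gather load, because the source tensor's trailing
dimension is 1 (the names below are illustrative, not from this patch):

    // Reading from tensor<1x4x1xi32> (trailing dim == 1) is
    // conservatively treated as a gather load.
    %c0 = arith.constant 0 : index
    %e = tensor.extract %src[%i, %j, %c0] : tensor<1x4x1xi32>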

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index cf1c33bf5eec85..7c75e0ff3044d9 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -499,3 +499,30 @@ transform.sequence failures(propagate) {
    %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
    %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
  }
+
+// -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_0d_tensor_extract(%arg0: tensor<f32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%arg2 : tensor<1x1x3xf32>) {
+  ^bb0(%arg4: f32):
+    %7 = tensor.extract %arg0[] : tensor<f32>
+    linalg.yield %7 : f32
+  } -> tensor<1x1x3xf32>
+  return %2 : tensor<1x1x3xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_0d_tensor_extract(
+// CHECK-SAME:     %[[ARG_0:.*]]: tensor<f32>
+// CHECK:           %[[EXTRACT:.*]] = tensor.extract %[[ARG_0]][] : tensor<f32>
+// CHECK:           vector.broadcast %[[EXTRACT]] : f32 to vector<1x1x3xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !transform.any_op):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+   %1 = get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ }
