[Mlir-commits] [mlir] 89b144e - [mlir][linalg] Vectorize tensor.extract using contiguous loads

Andrzej Warzynski llvmlistbot at llvm.org
Wed Feb 22 11:33:01 PST 2023


Author: Andrzej Warzynski
Date: 2023-02-22T19:29:10Z
New Revision: 89b144ece330b363713bec369d2d89dc85f715f5

URL: https://github.com/llvm/llvm-project/commit/89b144ece330b363713bec369d2d89dc85f715f5
DIFF: https://github.com/llvm/llvm-project/commit/89b144ece330b363713bec369d2d89dc85f715f5.diff

LOG: [mlir][linalg] Vectorize tensor.extract using contiguous loads

This patch implements vectorization of tensor.extract for n-D tensors (n
>= 2) using contiguous load operations, i.e. `vector.transfer_read`. This
is a follow-up to https://reviews.llvm.org/D137660, in which gather loads,
i.e. `vector.gather`, were used.

It is always safe to use gather load operations when the underlying
memory pattern is contiguous, but not vice versa. At the moment, the
following conditions have to be met for contiguous loads to be
generated:
  1. The _output tensor_ must effectively be 1-D, i.e. only its trailing
     dim may be non-unit and it must be > 1, e.g. `tensor<1x1x4xi32>`,
  2. The trailing dim of the _input tensor_ must be > 1, e.g.
     `tensor<1x1x4xi32>` would be fine, but not `tensor<1x4x1xi32>`.
If these conditions are not satisfied, gather loads are generated
instead.
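
To illustrate, below is a minimal sketch of a qualifying case, loosely
adapted from the `vectorize_nd_tensor_extract_transfer_read_complex` test
added in this patch (the function name, value names and shapes are
illustrative only):

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @example_contiguous(%src: tensor<45x80x16xf32>, %off: index,
                                %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
    %c79 = arith.constant 79 : index
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%init : tensor<1x4xf32>) {
    ^bb0(%out: f32):
      // The trailing extract index increments with the trailing loop dim;
      // the remaining indices are loop invariant => contiguous access.
      %i1 = linalg.index 1 : index
      %idx2 = arith.addi %off, %i1 : index
      %e = tensor.extract %src[%off, %c79, %idx2] : tensor<45x80x16xf32>
      linalg.yield %e : f32
    } -> tensor<1x4xf32>
    return %res : tensor<1x4xf32>
  }

With `transform.structured.vectorize` and the `vectorize_nd_extract`
attribute, the `tensor.extract` above is expected to lower to a single
`vector.transfer_read` producing a `vector<1x4xf32>` (see the CHECK lines
in the updated test file for the exact output).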

Condition 1 guarantees that the iteration space of the corresponding
`linalg.generic` Op is relatively simple. That makes analysing the
indices for `tensor.extract` rather straightforward.

Condition 2 is mostly there to avoid weird vectorisation patterns that
would result in vectors like `vector<1x1x1xi32>`. In practice, tensors
like `tensor<1x4x1xi32>` should be collapsed to `tensor<1x4xi32>` before
vectorisation, but that's beyond the scope of this patch.
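
Conversely, here is a hypothetical counter-example (illustrative shapes,
not taken from the test file): both trailing dims are unit dims, so
neither condition is met and the conservative gather path is taken:

  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  func.func @example_gather(%src: tensor<3x4x1xf32>,
                            %init: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> {
    %c0 = arith.constant 0 : index
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } outs(%init : tensor<1x4x1xf32>) {
    ^bb0(%out: f32):
      %i1 = linalg.index 1 : index
      // Trailing dims of both tensors are 1, so this extract is
      // vectorized with vector.gather rather than vector.transfer_read.
      %e = tensor.extract %src[%c0, %i1, %c0] : tensor<3x4x1xf32>
      linalg.yield %e : f32
    } -> tensor<1x4x1xf32>
    return %res : tensor<1x4x1xf32>
  }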

If needed, both conditions can be relaxed. I've not been able to find a
good motivating example for these, hence skipping. For reference,
`tosa.resize` (lowered to Linalg) was the driving example used here.

As a bonus, the test from "vectorization-unsupported.mlir" is moved to
"vectorization.mlir" with proper CHECK lines added.

Differential Revision: https://reviews.llvm.org/D141998

Co-authored-by: Diego Caballero <diegocaballero at google.com>

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    mlir/test/Dialect/Linalg/vectorization-unsupported.mlir


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 75d4595e4f7a3..fda9adf4462fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,143 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether \p val can be used for calculating an index for a contiguous
+/// load operation, i.e. whether \p val:
+///   * is invariant with respect to \p linalgOp, i.e. remains constant for
+///   all iterations, or
+///   * increments with stride one along loop dimension \p dim (when
+///   \p strideZero is false) or is not affected by it (when \p strideZero is true).
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
+                                    bool strideZero) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // Given the assumption on the shape of the target tensor, the index Op
+  // either:
+  //    * is constant (for non-trailing dims), or
+  //    * increments with stride one along the trailing dimension.
+  // Both cases are fine for contiguous loads.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values defined outside `linalgOp` are loop invariant.
+  if (!ancestor)
+    return true;
+
+  // Constants defined inside `linalgOp` are loop invariant too.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //    * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+  //    * a 1-D vector with the trailing dim equal to 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze whether the way each index is
+  // calculated is suitable for contiguous load operations (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This extractOp index must be a loop-invariant constant
+      continue;
+    }
+
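+    // Given the assumption on the output shape (only the trailing dim is
+    // non-unit), only the trailing extract index may vary with the trailing
+    // loop dimension; all other indices have to be invariant with respect
+    // to it.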
+    auto extractOpBottomIdx = indices.size() - 1;
+    auto strideOneDim = targetShape.size() - 1;
+    bool strideZero = (i != extractOpBottomIdx);
+    isContiguous &=
+        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
+  }
+
+  // The calculation of the trailing index must include the loop index. Given
+  // the assumption on the output tensor (which is defined by the iteration
+  // space), only the trailing dim matters.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
+
+  if (isContiguous) {
+    LDBG("Found contigous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -660,15 +797,64 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
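+  // An i32 zero, used below as the position for `vector.extractelement`.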
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices will
+  // either be scalars or would have been broadcast to vectors matching the
+  // result type. For indices that are vectors, there are two options:
+  //    * for non-trailing indices, all elements are identical (contiguous
+  //      loads are identified by looking for non-trailing indices that are
+  //      invariant with respect to the corresponding linalg.generic), or
+  //    * for trailing indices, the index vector will contain values with stride
+  //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
+  //      needed.
+  // This means that
+  //   * for scalar indices - just re-use it,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //    (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract` (a scalar load) is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different

diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
deleted file mode 100644
index aa1d6e34657e7..0000000000000
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
-
-// Masked vectorisation of `tensor.extract`:
-//   * requires the `{ vectorize_nd_extract }` attribute,
-//   * has not been implemented yet (hence the attribute is absent).
-// TOOD: Implement masked vectorization for `tensor.extract`
-
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 1 : index
-  %c1 = arith.constant 2 : index
-  // expected-error at +1 {{failed to vectorize op}}
-  %2 = linalg.generic {
-    indexing_maps = [#map1],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%arg1 : tensor<?x?xf32>) {
-  ^bb0(%arg3: f32):
-    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
-    linalg.yield %7 : f32
-  } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-   transform.structured.masked_vectorize %0 vector_sizes [3, 3]
- }

diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a43fd7a9fc5c5..a43cf9da514fe 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1584,7 +1584,6 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
-    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1613,7 +1612,7 @@ transform.sequence failures(propagate) {
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
@@ -1628,16 +1627,19 @@ func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
+// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
 // CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
-// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK:   vector.transfer_write %[[GATHER]]
+// CHECK:   %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
@@ -1646,6 +1648,56 @@ transform.sequence failures(propagate) {
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
+ // -----
+
+func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME:      {{.*}}: index,
+// CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = arith.constant 79 : index
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
+// CHECK:           %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
@@ -2119,6 +2171,56 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+    linalg.yield %7 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @extract_masked_vectorize(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                                        %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
+// CHECK:           %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
+// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
+// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_23:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract }
+ }
+
+// -----
+
 func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>,
                                  %arg1: tensor<8x32xf32>,
                                  %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {


        

