[Mlir-commits] [mlir] e28bbfe - Revert "[mlir][linalg] Vectorize tensor.extract using contiguous loads"

Benjamin Kramer llvmlistbot at llvm.org
Tue Feb 28 04:33:22 PST 2023


Author: Benjamin Kramer
Date: 2023-02-28T13:33:11+01:00
New Revision: e28bbfea5d482c1825b1799c57aedff4e0116619

URL: https://github.com/llvm/llvm-project/commit/e28bbfea5d482c1825b1799c57aedff4e0116619
DIFF: https://github.com/llvm/llvm-project/commit/e28bbfea5d482c1825b1799c57aedff4e0116619.diff

LOG: Revert "[mlir][linalg] Vectorize tensor.extract using contiguous loads"

This reverts commit 89b144ece330b363713bec369d2d89dc85f715f5. See
https://reviews.llvm.org/D141998 for a test case where this goes wrong.

Added: 
    mlir/test/Dialect/Linalg/vectorization-unsupported.mlir

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorization.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4ab82825d1a7..11ba3712d9d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
-    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
-
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
+        rewriter.create<arith::ConstantIndexOp>(
+            loc,
+            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,143 +630,6 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
-enum VectorMemoryAccessKind {
-  // TODO: ScalarBroadcast,
-  Contiguous,
-  Gather
-};
-
-/// Check whether /p val can be used for calculating an index for a contiguous
-/// load operation, i.e. whether /p val:
-///   * is invariant with respect to /p linalgOp, i.e. whether it remains
-///   constant for all iterations, and
-///   * increments with the loop iterator (when /p strideZero is false) or is
-///   not affected by the loop indices (/p strideZero is true).
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
-                                    bool strideZero) {
-  auto *block = linalgOp.getBlock();
-
-  // Bail out if this is a block argument for this linalg.generic Op.
-  // TODO: We could try analysing the corresponding affine map here.
-  if (val.dyn_cast<BlockArgument>())
-    return llvm::all_of(block->getArguments(),
-                        [&val](Value v) { return (v != val); });
-
-  Operation *defOp = val.getDefiningOp();
-  assert(defOp && "This is neither a block argument nor an operation result");
-
-  // Given the assumption on the shape of the target tensor, index Op is
-  // either:
-  //    * constant (for non-trailing dims), or
-  //    * increments with stride one together with the trailing dimension
-  // Both cases are fine for contigious loads.
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
-
-  auto *ancestor = block->findAncestorOpInBlock(*defOp);
-
-  // Values define outside `linalgOp`.
-  if (!ancestor)
-    return true;
-
-  // Values defined inside `linalgOp`, which are constant.
-  if (dyn_cast<arith::ConstantOp>(ancestor))
-    return true;
-
-  bool result = true;
-  for (auto op : ancestor->getOperands())
-    result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
-
-  return result;
-}
-
-/// Check whether the calculation of \p val is based on linalg.index Op with
-/// the dim attribute matching \p dim.
-static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
-  auto *block = linalgOp.getBlock();
-  auto targetShape = linalgOp.getStaticLoopRanges();
-
-  if (val.isa<BlockArgument>())
-    return false;
-
-  Operation *defOp = val.getDefiningOp();
-  assert(defOp && "This is neither a block argument nor an operation result");
-
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() == dim);
-
-  auto *ancestor = block->findAncestorOpInBlock(*defOp);
-
-  if (!ancestor)
-    return false;
-
-  bool result = false;
-  for (auto op : ancestor->getOperands())
-    result |= isBasedOnIndexOp(linalgOp, op, dim);
-
-  return result;
-}
-
-/// Check whether \p extractOp would be a gather or a contiguous load Op after
-/// vectorising \p linalgOp. Note that it is always safe to use gather load
-/// operations for contiguous loads (albeit slow), but not vice-versa. When in
-/// doubt, bail out and assume that \p extractOp is a gather load.
-static VectorMemoryAccessKind
-getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
-                                    LinalgOp &linalgOp) {
-
-  auto targetShape = linalgOp.getStaticLoopRanges();
-
-  // Assume that it's a gather load when reading _into_:
-  //    * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
-  //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
-
-  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
-
-  // Assume that it's a gather load when reading _from_ a tensor for which the
-  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
-    return VectorMemoryAccessKind::Gather;
-
-  bool isContiguous = true;
-
-  // Iterate over all indices. Analyze whether the way each index is calculate
-  // is suitable for contiguous load operations (e.g. loop invariant).
-  auto indices = extractOp.getIndices();
-  for (auto [i, indexVal] : llvm::enumerate(indices)) {
-    if (inputShape.getShape()[i] == 1) {
-      // This extractOp index must be a loop-invariant constant
-      continue;
-    }
-
-    auto extractOpBottomIdx = indices.size() - 1;
-    auto strideOneDim = targetShape.size() - 1;
-    bool strideZero = (i != extractOpBottomIdx);
-    isContiguous &=
-        isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
-  }
-
-  // The calculation of the trailing index must include the loop index. Given
-  // the assumption on the output tensor (which is defined by the iteration
-  // space), only the trailing dim matters.
-  auto extractOpTrailingIdx = indices.back();
-  isContiguous &=
-      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
-
-  if (isContiguous) {
-    LDBG("Found contigous load: " << extractOp);
-    return VectorMemoryAccessKind::Contiguous;
-  }
-
-  return VectorMemoryAccessKind::Gather;
-}
-
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -797,64 +660,15 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  VectorMemoryAccessKind memAccessKind =
-      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
-
-  // 1. Handle gather access
-  if (memAccessKind == VectorMemoryAccessKind::Gather) {
-    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
-
-    // Generate the gather load
-    Operation *gatherOp = rewriter.create<vector::GatherOp>(
-        loc, resultType, extractOp.getTensor(), baseIndices, offset,
-        maskConstantOp, passThruConstantOp);
-    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
-
-    LDBG("Vectorised as gather load: " << extractOp);
-    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
-  }
-
-  // 2. Handle contiguous access.
-  SmallVector<Value> transferReadIdxs;
-  auto resTrailingDim = resultType.getShape().back();
-  auto zero = rewriter.create<arith::ConstantOp>(
-      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
-
-  // Collect indices for `vector.transfer_read`. At this point, the indices will
-  // either be scalars or would have been broadcast to vectors matching the
-  // result type. For indices that are vectors, there are two options:
-  //    * for non-trailing indices, all elements are identical (contiguous
-  //      loads are identified by looking for non-trailing indices that are
-  //      invariant with respect to the corresponding linalg.generic), or
-  //    * for trailing indices, the index vector will contain values with stride
-  //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
-  //      needed.
-  // This means that
-  //   * for scalar indices - just re-use it,
-  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
-  //    (0th) element and use that.
-  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
-    auto idx = bvm.lookup(extractOp.getIndices()[i]);
-    if (idx.getType().isIndex()) {
-      transferReadIdxs.push_back(idx);
-      continue;
-    }
-
-    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
-        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
-        bvm.lookup(extractOp.getIndices()[i]));
-    transferReadIdxs.push_back(
-        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
-  }
-
-  // `tensor.extract_element` is always in-bounds, hence the following holds.
-  SmallVector<bool> inBounds(resultType.getRank(), true);
+  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
 
-  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
-      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
+  // Generate the gather load
+  Operation *gatherOp = rewriter.create<vector::GatherOp>(
+      loc, resultType, extractOp.getTensor(), baseIndices, offset,
+      maskConstantOp, passThruConstantOp);
+  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
 
-  LDBG("Vectorised as contiguous load: " << extractOp);
-  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different

diff  --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
new file mode 100644
index 000000000000..aa1d6e34657e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
+
+// Masked vectorisation of `tensor.extract`:
+//   * requires the `{ vectorize_nd_extract }` attribute,
+//   * has not been implemented yet (hence the attribute is absent).
+// TODO: Implement masked vectorization for `tensor.extract`
+
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  // expected-error at +1 {{failed to vectorize op}}
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+    linalg.yield %7 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   transform.structured.masked_vectorize %0 vector_sizes [3, 3]
+ }

diff  --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a43cf9da514f..a43fd7a9fc5c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1584,6 +1584,7 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
+    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1612,7 +1613,7 @@ transform.sequence failures(propagate) {
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
@@ -1627,19 +1628,16 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
-// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
 // CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
-// CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[GATHER]]
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
@@ -1648,56 +1646,6 @@ transform.sequence failures(propagate) {
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
- // -----
-
-func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
-  %c79 = arith.constant 79 : index
-  %25 = linalg.generic {
-    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%extracted_slice : tensor<1x4xf32>) {
-  ^bb0(%out: f32):
-    %26 = linalg.index 0 : index
-    %27 = arith.addi %arg0, %26 : index
-    %28 = arith.addi %27, %arg2 : index
-    %29 = linalg.index 1 : index
-    %30 = arith.addi %arg1, %29 : index
-    %31 = arith.addi %30, %arg4 : index
-    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
-    linalg.yield %extracted : f32
-  } -> tensor<1x4xf32>
-  return %25 : tensor<1x4xf32>
-}
-
-// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex
-// CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME:      {{.*}}: index,
-// CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
-// CHECK:           %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK:           %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:           %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK:           %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
-// CHECK:           %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
-// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
-// CHECK:           %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
-// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK:           %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK:           %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
-   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
- }
-
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
@@ -2171,56 +2119,6 @@ transform.sequence failures(propagate) {
 
 // -----
 
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 1 : index
-  %c1 = arith.constant 2 : index
-  %2 = linalg.generic {
-    indexing_maps = [#map1],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%arg1 : tensor<?x?xf32>) {
-  ^bb0(%arg3: f32):
-    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
-    linalg.yield %7 : f32
-  } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL:   func.func @extract_masked_vectorize(
-// CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME:                                        %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
-// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
-// CHECK:           %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
-// CHECK:           %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
-// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
-// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
-// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
-// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
-// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
-// CHECK:           %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
-// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
-// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
-// CHECK:           %[[VAL_23:.*]] = arith.constant 0 : index
-// CHECK:           %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-   transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract }
- }
-
-// -----
-
 func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>,
                                  %arg1: tensor<8x32xf32>,
                                  %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {


        


More information about the Mlir-commits mailing list