[Mlir-commits] [mlir] e28bbfe - Revert "[mlir][linalg] Vectorize tensor.extract using contiguous loads"
Benjamin Kramer
llvmlistbot at llvm.org
Tue Feb 28 04:33:22 PST 2023
Author: Benjamin Kramer
Date: 2023-02-28T13:33:11+01:00
New Revision: e28bbfea5d482c1825b1799c57aedff4e0116619
URL: https://github.com/llvm/llvm-project/commit/e28bbfea5d482c1825b1799c57aedff4e0116619
DIFF: https://github.com/llvm/llvm-project/commit/e28bbfea5d482c1825b1799c57aedff4e0116619.diff
LOG: Revert "[mlir][linalg] Vectorize tensor.extract using contiguous loads"
This reverts commit 89b144ece330b363713bec369d2d89dc85f715f5. See
https://reviews.llvm.org/D141998 for a test case where this goes wrong.
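For context, the behaviour restored by this revert is that `tensor.extract` inside a `linalg.generic` is again vectorized unconditionally as a `vector.gather` when `vectorize_nd_extract` is requested, rather than being turned into a `vector.transfer_read` for contiguous access patterns. A minimal sketch of the kind of op this affects, adapted from the tests updated below (function and value names are illustrative, not from the commit):

  #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
  func.func @example(%src: tensor<3x3x3xf32>, %init: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
    %c0 = arith.constant 0 : index
    %0 = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel", "parallel"]
    } outs(%init : tensor<1x1x3xf32>) {
    ^bb0(%out: f32):
      // The trailing extract index follows the innermost loop index; the
      // reverted patch classified this as a contiguous load, while the
      // restored path always emits a gather.
      %idx = linalg.index 2 : index
      %e = tensor.extract %src[%c0, %c0, %idx] : tensor<3x3x3xf32>
      linalg.yield %e : f32
    } -> tensor<1x1x3xf32>
    return %0 : tensor<1x1x3xf32>
  }

With this revert, `transform.structured.vectorize` with `{ vectorize_nd_extract }` lowers such an extract to `vector.gather`, as checked by the renamed `@vectorize_nd_tensor_extract_idx_from_iteration_index` test in vectorization.mlir below.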
Added:
mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 4ab82825d1a7..11ba3712d9d1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
const size_t numIndices = extractOp.getIndices().size();
for (size_t i = 1; i < numIndices; i++) {
- Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
-
auto dimSize = broadcastIfNeeded(
rewriter,
- rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
+ rewriter.create<arith::ConstantIndexOp>(
+ loc,
+ extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
indexVecType.getShape());
offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,143 +630,6 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
return offset;
}
-enum VectorMemoryAccessKind {
- // TODO: ScalarBroadcast,
- Contiguous,
- Gather
-};
-
-/// Check whether \p val can be used for calculating an index for a contiguous
-/// load operation, i.e. whether \p val:
-/// * is invariant with respect to \p linalgOp, i.e. whether it remains
-/// constant for all iterations, and
-/// * increments with the loop iterator (when \p strideZero is false) or is
-/// not affected by the loop indices (\p strideZero is true).
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val, size_t dim,
- bool strideZero) {
- auto *block = linalgOp.getBlock();
-
- // Bail out if this is a block argument for this linalg.generic Op.
- // TODO: We could try analysing the corresponding affine map here.
- if (val.dyn_cast<BlockArgument>())
- return llvm::all_of(block->getArguments(),
- [&val](Value v) { return (v != val); });
-
- Operation *defOp = val.getDefiningOp();
- assert(defOp && "This is neither a block argument nor an operation result");
-
- // Given the assumption on the shape of the target tensor, index Op is
- // either:
- // * constant (for non-trailing dims), or
- // * increments with stride one together with the trailing dimension
-  // Both cases are fine for contiguous loads.
- if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
- return strideZero ? (indexOp.getDim() != dim) : (indexOp.getDim() == dim);
-
- auto *ancestor = block->findAncestorOpInBlock(*defOp);
-
-  // Values defined outside `linalgOp`.
- if (!ancestor)
- return true;
-
- // Values defined inside `linalgOp`, which are constant.
- if (dyn_cast<arith::ConstantOp>(ancestor))
- return true;
-
- bool result = true;
- for (auto op : ancestor->getOperands())
- result &= isContiguousLoadIdx(linalgOp, op, dim, strideZero);
-
- return result;
-}
-
-/// Check whether the calculation of \p val is based on linalg.index Op with
-/// the dim attribute matching \p dim.
-static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
- auto *block = linalgOp.getBlock();
- auto targetShape = linalgOp.getStaticLoopRanges();
-
- if (val.isa<BlockArgument>())
- return false;
-
- Operation *defOp = val.getDefiningOp();
- assert(defOp && "This is neither a block argument nor an operation result");
-
- if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
- return (indexOp.getDim() == dim);
-
- auto *ancestor = block->findAncestorOpInBlock(*defOp);
-
- if (!ancestor)
- return false;
-
- bool result = false;
- for (auto op : ancestor->getOperands())
- result |= isBasedOnIndexOp(linalgOp, op, dim);
-
- return result;
-}
-
-/// Check whether \p extractOp would be a gather or a contiguous load Op after
-/// vectorising \p linalgOp. Note that it is always safe to use gather load
-/// operations for contiguous loads (albeit slow), but not vice-versa. When in
-/// doubt, bail out and assume that \p extractOp is a gather load.
-static VectorMemoryAccessKind
-getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
- LinalgOp &linalgOp) {
-
- auto targetShape = linalgOp.getStaticLoopRanges();
-
- // Assume that it's a gather load when reading _into_:
-  //   * an n-D vector, like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
-  //   * a 1-D vector with the trailing dim equal to 1, e.g. `tensor<1x4x1xi32>`.
- // TODO: Relax these conditions.
- if ((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
- targetShape.back() == 1)
- return VectorMemoryAccessKind::Gather;
-
- auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
-
- // Assume that it's a gather load when reading _from_ a tensor for which the
- // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
- // TODO: Relax this condition.
- if (inputShape.getShape().back() == 1)
- return VectorMemoryAccessKind::Gather;
-
- bool isContiguous = true;
-
-  // Iterate over all indices. Analyze whether the way each index is calculated
- // is suitable for contiguous load operations (e.g. loop invariant).
- auto indices = extractOp.getIndices();
- for (auto [i, indexVal] : llvm::enumerate(indices)) {
- if (inputShape.getShape()[i] == 1) {
- // This extractOp index must be a loop-invariant constant
- continue;
- }
-
- auto extractOpBottomIdx = indices.size() - 1;
- auto strideOneDim = targetShape.size() - 1;
- bool strideZero = (i != extractOpBottomIdx);
- isContiguous &=
- isContiguousLoadIdx(linalgOp, indexVal, strideOneDim, strideZero);
- }
-
- // The calculation of the trailing index must include the loop index. Given
- // the assumption on the output tensor (which is defined by the iteration
- // space), only the trailing dim matters.
- auto extractOpTrailingIdx = indices.back();
- isContiguous &=
- isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, targetShape.size() - 1);
-
- if (isContiguous) {
-    LDBG("Found contiguous load: " << extractOp);
- return VectorMemoryAccessKind::Contiguous;
- }
-
- return VectorMemoryAccessKind::Gather;
-}
-
/// Helper function to vectorize the tensor.extract operations. Returns
/// VectorizationStatus::NewOp to signal the vectorization algorithm that it
/// should map the produced operations. This function is meant to be used as a
@@ -797,64 +660,15 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
extractOp.getIndices().size(),
rewriter.create<arith::ConstantIndexOp>(loc, 0));
- VectorMemoryAccessKind memAccessKind =
- getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
-
- // 1. Handle gather access
- if (memAccessKind == VectorMemoryAccessKind::Gather) {
- Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
-
- // Generate the gather load
- Operation *gatherOp = rewriter.create<vector::GatherOp>(
- loc, resultType, extractOp.getTensor(), baseIndices, offset,
- maskConstantOp, passThruConstantOp);
- gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
-
- LDBG("Vectorised as gather load: " << extractOp);
- return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
- }
-
- // 2. Handle contiguous access.
- SmallVector<Value> transferReadIdxs;
- auto resTrailingDim = resultType.getShape().back();
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
-
- // Collect indices for `vector.transfer_read`. At this point, the indices will
- // either be scalars or would have been broadcast to vectors matching the
- // result type. For indices that are vectors, there are two options:
- // * for non-trailing indices, all elements are identical (contiguous
- // loads are identified by looking for non-trailing indices that are
- // invariant with respect to the corresponding linalg.generic), or
- // * for trailing indices, the index vector will contain values with stride
- // one, but for `vector.transfer_read` only the first (i.e. 0th) index is
- // needed.
- // This means that
- // * for scalar indices - just re-use it,
- // * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
- // (0th) element and use that.
- for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
- auto idx = bvm.lookup(extractOp.getIndices()[i]);
- if (idx.getType().isIndex()) {
- transferReadIdxs.push_back(idx);
- continue;
- }
-
- auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
- bvm.lookup(extractOp.getIndices()[i]));
- transferReadIdxs.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
- }
-
- // `tensor.extract_element` is always in-bounds, hence the following holds.
- SmallVector<bool> inBounds(resultType.getRank(), true);
+ Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
- auto transferReadOp = rewriter.create<vector::TransferReadOp>(
- loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
+ // Generate the gather load
+ Operation *gatherOp = rewriter.create<vector::GatherOp>(
+ loc, resultType, extractOp.getTensor(), baseIndices, offset,
+ maskConstantOp, passThruConstantOp);
+ gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
- LDBG("Vectorised as contiguous load: " << extractOp);
- return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
+ return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
}
/// Emit reduction operations if the shapes of the value to reduce is
different
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
new file mode 100644
index 000000000000..aa1d6e34657e
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
+
+// Masked vectorisation of `tensor.extract`:
+// * requires the `{ vectorize_nd_extract }` attribute,
+// * has not been implemented yet (hence the attribute is absent).
+// TODO: Implement masked vectorization for `tensor.extract`
+
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 1 : index
+ %c1 = arith.constant 2 : index
+  // expected-error @+1 {{failed to vectorize op}}
+ %2 = linalg.generic {
+ indexing_maps = [#map1],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%arg1 : tensor<?x?xf32>) {
+ ^bb0(%arg3: f32):
+ %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+ linalg.yield %7 : f32
+ } -> tensor<?x?xf32>
+ return %2 : tensor<?x?xf32>
+}
+
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+ transform.structured.masked_vectorize %0 vector_sizes [3, 3]
+ }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index a43cf9da514f..a43fd7a9fc5c 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1584,6 +1584,7 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
iterator_types = ["parallel", "parallel", "parallel"]
} outs(%arg2 : tensor<1x1x3xf32>) {
^bb0(%arg4: f32):
+ %3 = linalg.index 2 : index
%7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
linalg.yield %7 : f32
} -> tensor<1x1x3xf32>
@@ -1612,7 +1613,7 @@ transform.sequence failures(propagate) {
// -----
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
%1 = linalg.generic {
indexing_maps = [#map1],
iterator_types = ["parallel", "parallel", "parallel"]
@@ -1627,19 +1628,16 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
return %1 : tensor<1x1x3xf32>
}
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
-// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK: %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
+// CHECK: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK: %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK: %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK: %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
-// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK: %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
+// CHECK: vector.transfer_write %[[GATHER]]
transform.sequence failures(propagate) {
^bb1(%arg1: !pdl.operation):
@@ -1648,56 +1646,6 @@ transform.sequence failures(propagate) {
%2 = transform.structured.vectorize %1 { vectorize_nd_extract }
}
- // -----
-
-func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
- %c79 = arith.constant 79 : index
- %25 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
- iterator_types = ["parallel", "parallel"]
- } outs(%extracted_slice : tensor<1x4xf32>) {
- ^bb0(%out: f32):
- %26 = linalg.index 0 : index
- %27 = arith.addi %arg0, %26 : index
- %28 = arith.addi %27, %arg2 : index
- %29 = linalg.index 1 : index
- %30 = arith.addi %arg1, %29 : index
- %31 = arith.addi %30, %arg4 : index
- %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
- linalg.yield %extracted : f32
- } -> tensor<1x4xf32>
- return %25 : tensor<1x4xf32>
-}
-
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
-// CHECK-SAME: {{.*}}: index,
-// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK: %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
-// CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
-// CHECK: %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
-// CHECK: %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
-// CHECK: %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
-// CHECK: %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
- %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
- }
-
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
@@ -2171,56 +2119,6 @@ transform.sequence failures(propagate) {
// -----
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
- %c0 = arith.constant 1 : index
- %c1 = arith.constant 2 : index
- %2 = linalg.generic {
- indexing_maps = [#map1],
- iterator_types = ["parallel", "parallel"]
- } outs(%arg1 : tensor<?x?xf32>) {
- ^bb0(%arg3: f32):
- %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
- linalg.yield %7 : f32
- } -> tensor<?x?xf32>
- return %2 : tensor<?x?xf32>
-}
-
-// CHECK-LABEL: func.func @extract_masked_vectorize(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant 2 : index
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
-// CHECK: %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
-// CHECK: %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
-// CHECK: %[[VAL_14:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
-// CHECK: %[[VAL_16:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
-// CHECK: %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
-// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
-// CHECK: %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
-// CHECK: %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
-// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
-// CHECK: %[[VAL_23:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
- transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract }
- }
-
-// -----
-
func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>,
%arg1: tensor<8x32xf32>,
%arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {