[Mlir-commits] [mlir] [mlir][linalg] Upgrade vectorisation of tensor.extract (PR #100582)
Andrzej WarzyĆski
llvmlistbot at llvm.org
Thu Jul 25 07:52:12 PDT 2024
https://github.com/banach-space created https://github.com/llvm/llvm-project/pull/100582
This PR removes the assumption that reading from a dynamic tensor is
always a gather load:
```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```
That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of `tensor.extract` has been around for > 1 year and has
been quite stable, we can safely relax it.
This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.
As expected, the following tests required updating (`vector.gather` ->
`vector.transfer_read`):
* @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous
* @vectorize_linalg_index (in vectorization-scalable.mlir)
A similar test for scalable vectors has also been added.
>From aae7571a8213a429a64f7918c0d0c560d5efdead Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski <andrzej.warzynski at arm.com>
Date: Tue, 23 Jul 2024 19:18:39 +0100
Subject: [PATCH] [mlir][linalg] Upgrade vectorisation of tensor.extract
This PR removes the assumption that reading from a dynamic tensor is
always a gather load:
```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```
That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of `tensor.extract` has been around for > 1 year and has
been quite stable, we can safely relax it.
This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.
As expected, the following tests required updating (`vector.gather` ->
`vector.transfer_read`):
* @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous
* @vectorize_linalg_index (in vectorization-scalable.mlir)
A similar test for scalable vectors has also been added.
---
.../Linalg/Transforms/Vectorization.cpp | 58 ++++----
.../Linalg/vectorization-scalable.mlir | 9 +-
.../vectorize-tensor-extract-masked.mlir | 129 +++++++++++++-----
3 files changed, 123 insertions(+), 73 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9185663799e52..5c8d4a00bc35f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -808,14 +808,13 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
+ assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
+ assert(resType.getShape().back() != 1 &&
"1-D vectors with the trailing dim eqaual 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +848,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
bool result = true;
for (auto op : ancestor->getOperands())
- result &= isLoopInvariantIdx(linalgOp, op);
+ result &= isLoopInvariantIdx(linalgOp, op, resType);
return result;
}
@@ -871,13 +870,12 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
/// updated to `true` when such an op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
- bool &foundIndexOp) {
+ bool &foundIndexOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
+ assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
+ assert(resType.getShape().back() != 1 &&
"1-D vectors with the trailing dim 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -912,46 +910,40 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool result = false;
for (auto op : ancestor->getOperands())
- result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+ result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
return result;
}
/// Infer the memory access pattern for the input ExtractOp
///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtractOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous loads), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
- LinalgOp &linalgOp) {
+ LinalgOp &linalgOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
- // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+ // 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
- // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
- // gather load.
- // TODO: Relax this condition.
- if (linalgOp.hasDynamicShape())
- return VectorMemoryAccessKind::Gather;
-
// 1. Assume that it's a gather load when reading _into_:
- // * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
- // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+ // * an n-D "vector", like `vector<1x2x4xi32>` or `vector<2x1x4xi32>`, or
+ // * a 1-D "vector" with the trailing dim equal 1, e.g. `vector<1x4x1xi32>`.
// TODO: Relax these conditions.
- // FIXME: This condition assumes non-dynamic sizes.
- if ((llvm::count_if(targetShape,
+ if ((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) != 1) ||
- targetShape.back() == 1)
+ resType.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
// 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -972,7 +964,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (inputShape.getShape()[i] == 1)
continue;
- leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+ leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
}
if (!leadingIdxsLoopInvariant) {
@@ -989,7 +981,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// 4a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
- isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+ isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
@@ -1001,7 +993,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// This is what the following bool captures.
bool foundIndexOp = false;
bool isContiguousLoad =
- isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+ isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp, resType);
isContiguousLoad &= foundIndexOp;
if (isContiguousLoad) {
@@ -1042,7 +1034,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
rewriter.create<arith::ConstantIndexOp>(loc, 0));
VectorMemoryAccessKind memAccessKind =
- getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+ getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4ee3088cc3778..c3a30e3ee209e 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
// CHECK-LABEL: @vectorize_linalg_index
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 964565620fd01..31a754d934368 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
%c79 = arith.constant 79 : index
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
- } outs(%extracted_slice : tensor<?x?xf32>) {
+ } outs(%output : tensor<?x?xf32>) {
^bb0(%out: f32):
%2 = linalg.index 1 : index
- %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
- %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
linalg.yield %extracted : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
+ %c79 = arith.constant 79 : index
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%output : tensor<?x?xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 1 : index
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
%c16 = arith.constant 16 : index
%1 = linalg.generic {
More information about the Mlir-commits
mailing list