[Mlir-commits] [mlir] 315ba77 - [mlir][linalg] Vectorisation of tensor.extract - dynamic shapes (#100582)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 19 11:53:15 PDT 2024
Author: Andrzej Warzyński
Date: 2024-09-19T19:53:11+01:00
New Revision: 315ba7740663208f8bc45a7e4f145dc1df79500c
URL: https://github.com/llvm/llvm-project/commit/315ba7740663208f8bc45a7e4f145dc1df79500c
DIFF: https://github.com/llvm/llvm-project/commit/315ba7740663208f8bc45a7e4f145dc1df79500c.diff
LOG: [mlir][linalg] Vectorisation of tensor.extract - dynamic shapes (#100582)
This PR removes the assumption that reading from a dynamic tensor is
always a gather load:
```mlir
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
```
That assumption was originally introduced to simplify the implementation
and to reduce the number of cases to consider. Now that the
vectorisation of `tensor.extract` has been around for > 1 year and has
been quite stable, we can safely relax it.
This is a relatively small change - rather than using the parent linalg
Op to infer the target output shape (not possible with dynamic shapes),
the vectorizer will use the (previously constructed) output vector
shape instead.
As expected, the following tests required updating (`vector.gather` ->
`vector.transfer_read`):
* `@masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous`
* `@vectorize_linalg_index` (vectorization-scalable.mlir)
A similar test for scalable vectors has also been added.
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization-scalable.mlir
mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..a376afa5ddab12 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,12 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
-/// Checks whether /p val can be used for calculating a loop invariant index.
-static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
+/// Checks whether `val` can be used for calculating a loop invariant index.
+static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
+ VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+ assert(((llvm::count_if(resType.getShape(),
+ [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +849,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
bool result = true;
for (auto op : ancestor->getOperands())
- result &= isLoopInvariantIdx(linalgOp, op);
+ result &= isLoopInvariantIdx(linalgOp, op, resType);
return result;
}
@@ -871,10 +871,9 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
/// updated to `true` when such an op is found.
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
- bool &foundIndexOp) {
+ bool &foundIndexOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
+ assert(((llvm::count_if(resType.getShape(),
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
@@ -910,44 +909,38 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool result = false;
for (auto op : ancestor->getOperands())
- result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+ result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
return result;
}
/// Infer the memory access pattern for the input ExtractOp
///
-/// Based on the operation shapes and indices (usually based on the iteration
-/// space of the parent `linalgOp` operation), decides whether the input
-/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
-/// gather load.
+/// Based on the ExtractOp result shape and the access indices, decides whether
+/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
+/// or a gather load. When analysing the ExtractOp indices (to identify
+/// contiguous loads), this method looks for "loop" invariant indices (e.g.
+/// block arguments) and indices that change linearly (e.g. via `linalg.index`
+/// Op).
///
/// Note that it is always safe to use gather load operations for contiguous
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
/// that `extractOp` is a gather load.
static VectorMemoryAccessKind
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
- LinalgOp &linalgOp) {
+ LinalgOp &linalgOp, VectorType resType) {
- auto targetShape = linalgOp.getStaticLoopRanges();
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
- // 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
+ // 0. Is this a 0-D tensor? If yes then this is a scalar broadcast.
if (inputShape.getShape().empty())
return VectorMemoryAccessKind::ScalarBroadcast;
- // 0.2 In the case of dynamic shapes just bail-out and assume that it's a
- // gather load.
- // TODO: Relax this condition.
- if (linalgOp.hasDynamicShape())
- return VectorMemoryAccessKind::Gather;
-
// True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
// otherwise.
- bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
- return dimSize > 1;
- }) == 1);
-
+ bool isOutput1DVector =
+ (llvm::count_if(resType.getShape(),
+ [](int64_t dimSize) { return dimSize > 1; }) == 1);
// 1. Assume that it's a gather load when reading non-1D vector.
if (!isOutput1DVector)
return VectorMemoryAccessKind::Gather;
@@ -965,7 +958,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (inputShape.getShape()[i] == 1)
continue;
- leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
+ leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
}
if (!leadingIdxsLoopInvariant) {
@@ -982,7 +975,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// 3a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
- isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+ isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
@@ -993,8 +986,8 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// This effectively means that it must be based on the trailing loop index.
// This is what the following bool captures.
bool foundIndexOp = false;
- bool isContiguousLoad =
- isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
+ bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
+ foundIndexOp, resType);
isContiguousLoad &= foundIndexOp;
if (isContiguousLoad) {
@@ -1035,7 +1028,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
rewriter.create<arith::ConstantIndexOp>(loc, 0));
VectorMemoryAccessKind memAccessKind =
- getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+ getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
// 1. Handle gather access
if (memAccessKind == VectorMemoryAccessKind::Gather) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
index 4ee3088cc37787..c3a30e3ee209e8 100644
--- a/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-scalable.mlir
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
// CHECK-LABEL: @vectorize_linalg_index
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
-// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
-// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
-// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
-// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
-// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
-// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
+// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
module attributes {transform.with_named_sequence} {
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
index 964565620fd01f..31a754d9343682 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
%c79 = arith.constant 79 : index
%1 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
iterator_types = ["parallel", "parallel"]
- } outs(%extracted_slice : tensor<?x?xf32>) {
+ } outs(%output : tensor<?x?xf32>) {
^bb0(%out: f32):
%2 = linalg.index 1 : index
- %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
- %extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
linalg.yield %extracted : f32
} -> tensor<?x?xf32>
return %1 : tensor<?x?xf32>
}
// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: index,
-// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
-// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
-// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
-// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
-// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
-// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
-// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
-// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
-// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
-// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
-// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
-// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
-// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
-// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
-// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
-// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
-// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
-// CHECK: }
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
// -----
+func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+ %src: tensor<?x?xf32>,
+ %output : tensor<?x?xf32>,
+ %idx: index) -> tensor<?x?xf32> {
+
+ %c79 = arith.constant 79 : index
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%output : tensor<?x?xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 1 : index
+ %3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
+ %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<?x?xf32>
+ return %1 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[IDX:.*]]: index)
+
+/// Create the mask
+// CHECK: %[[C79:.*]] = arith.constant 79 : index
+// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
+// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
+// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
+// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
+
+/// TODO: This transfer_read is redundant - remove
+// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+
+/// Calculate the index vector
+// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
+// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
+// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : i32
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
+
+/// Extract the starting point from the index vector
+// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
+
+/// Final read and write
+// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
+// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
%c16 = arith.constant 16 : index
%1 = linalg.generic {
More information about the Mlir-commits
mailing list