[Mlir-commits] [mlir] [mlir][linalg] Generate `vector.transfer_read` for contiguous `tensor.extract` loads (PR #76436)
Prathamesh Tagore
llvmlistbot at llvm.org
Mon Jan 1 00:24:27 PST 2024
https://github.com/meshtag updated https://github.com/llvm/llvm-project/pull/76436
>From 1b9576e05d36cfd95a5834e8b4e3dc76d583b130 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Mon, 25 Dec 2023 07:56:44 +0000
Subject: [PATCH 1/5] Add algorithm skeleton
---
.../Linalg/Transforms/Vectorization.cpp | 108 ++++--------------
1 file changed, 22 insertions(+), 86 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c21d007c931b9b..783b455e4f0d06 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -788,9 +788,6 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
- "n-D vectors are not yet supported");
assert(targetShape.back() != 1 &&
"1-D vectors with the trailing dim eqaual 1 are not yet supported");
@@ -806,12 +803,8 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
- // IndexOp is loop invariant as long as its result remains constant across
- // iterations. Given the assumptions on the loop ranges above, only the
- // trailing loop dim ever changes.
- auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
- return (indexOp.getDim() != trailingLoopDim);
+ return false;
auto *ancestor = block->findAncestorOpInBlock(*defOp);
@@ -830,50 +823,23 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
return result;
}
-/// Check whether \p val could be used for calculating the trailing index for a
-/// contiguous load operation.
-///
-/// There are currently 3 types of values that are allowed here:
-/// 1. loop-invariant values,
-/// 2. values that increment by 1 with every loop iteration,
-/// 3. results of basic arithmetic operations (linear and continuous)
-/// involving 1., 2. and 3.
-/// This method returns True if indeed only such values are used in calculating
-/// \p val.
-///
-/// Additionally, the trailing index for a contiguous load operation should
-/// increment by 1 with every loop iteration, i.e. be based on:
-/// * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
-static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
- bool &foundIndexOp) {
-
+static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val, int64_t valuePosInExtract) {
auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
- "n-D vectors are not yet supported");
assert(targetShape.back() != 1 &&
"1-D vectors with the trailing dim 1 are not yet supported");
- // Blocks outside _this_ linalg.generic are effectively loop invariant.
- // However, analysing block arguments for _this_ linalg.generic Op is a bit
- // tricky. Just bail out in the latter case.
- // TODO: We could try analysing the corresponding affine map here.
auto *block = linalgOp.getBlock();
if (isa<BlockArgument>(val))
- return llvm::all_of(block->getArguments(),
- [&val](Value v) { return (v != val); });
+ return false;
Operation *defOp = val.getDefiningOp();
- assert(defOp && "This is neither a block argument nor an operation result");
+ assert(defOp && "This is not an operation result");
- // Given the assumption on the loop ranges above, only the trailing loop
- // index is not constant.
- auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+ auto loopDim = linalgOp.getStaticLoopRanges().size() - valuePosInExtract;
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
- foundIndexOp = (indexOp.getDim() == trailingLoopDim);
- return true;
+ if (indexOp.getDim() == loopDim) {
+ return true;
+ }
}
auto *ancestor = block->findAncestorOpInBlock(*defOp);
@@ -889,9 +855,9 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
bool result = false;
for (auto op : ancestor->getOperands())
- result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
+ result |= isProperLinalgIdx(linalgOp, op, valuePosInExtract);
- return result;
+ return result;
}
/// Check whether \p extractOp would be a gather or a contiguous load Op after
@@ -915,14 +881,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (linalgOp.hasDynamicShape())
return VectorMemoryAccessKind::Gather;
- // 1. Assume that it's a gather load when reading _into_:
- // * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
- // * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
- // TODO: Relax these conditions.
- // FIXME: This condition assumes non-dynamic sizes.
- if ((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
- targetShape.back() == 1)
+ if (targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
// 2. Assume that it's a gather load when reading _from_ a tensor for which
@@ -931,51 +890,28 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (inputShape.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
- bool leadingIdxsLoopInvariant = true;
+ bool isLoopInvariantLoad = true;
+ bool isProperLinalgIdxLoad = true;
- // 3. Analyze the leading indices of `extractOp`.
- // Look at the way each index is calculated and decide whether it is suitable
- // for a contiguous load, i.e. whether it's loop invariant.
auto indices = extractOp.getIndices();
- auto leadIndices = indices.drop_back(1);
-
- for (auto [i, indexVal] : llvm::enumerate(leadIndices)) {
+ for (auto [i, indexVal] : llvm::enumerate(indices)) {
if (inputShape.getShape()[i] == 1)
continue;
- leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
- }
+ isLoopInvariantLoad &= isLoopInvariantIdx(linalgOp, indexVal);
+ isProperLinalgIdxLoad &= isProperLinalgIdx(linalgOp, indexVal, i);
- if (!leadingIdxsLoopInvariant) {
- LDBG("Found gather load: " << extractOp);
- return VectorMemoryAccessKind::Gather;
+ if (!isLoopInvariantLoad && !isProperLinalgIdxLoad) {
+ LDBG("Found gather load: " << extractOp);
+ return VectorMemoryAccessKind::Gather;
+ }
}
- // 4. Analyze the trailing index for `extractOp`.
- // At this point we know that the leading indices are loop invariant. This
- // means that is potentially a scalar or a contiguous load. We can decide
- // based on the trailing idx.
- auto extractOpTrailingIdx = indices.back();
-
- // 4a. Scalar broadcast load
- // If the trailing index is loop invariant then this is a scalar load.
- if (leadingIdxsLoopInvariant &&
- isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
+ if (isLoopInvariantLoad) {
LDBG("Found scalar broadcast load: " << extractOp);
-
return VectorMemoryAccessKind::ScalarBroadcast;
}
-
- // 4b. Contiguous loads
- // The trailing `extractOp` index should increment with every loop iteration.
- // This effectively means that it must be based on the trailing loop index.
- // This is what the following bool captures.
- bool foundIndexOp = false;
- bool isContiguousLoad =
- isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
- isContiguousLoad &= foundIndexOp;
-
- if (isContiguousLoad) {
+ else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}
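A minimal sketch of the n-D access pattern the generalized check is meant to accept (a hypothetical example for illustration, not a test from this patch): every tensor.extract index comes from linalg.index for the matching dimension, so the load can lower to a single vector.transfer_read instead of a vector.gather.

func.func @proper_idx_sketch(%src: tensor<4x8xf32>,
                             %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<4x8xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    // linalg.index 0 is used at extract position 0 and linalg.index 1 at
    // position 1, so both indices pass the position-based check and the
    // access has unit stride in every dimension.
    %e = tensor.extract %src[%i, %j] : tensor<4x8xf32>
    linalg.yield %e : f32
  } -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}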
>From ddfd0de1d7911017dc54058b9e8a29c9a62a86c5 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Tue, 26 Dec 2023 05:54:47 +0000
Subject: [PATCH 2/5] First draft
---
.../Linalg/Transforms/Vectorization.cpp | 59 +++--
.../Linalg/vectorize-tensor-extract.mlir | 250 +++++++++++++++---
2 files changed, 249 insertions(+), 60 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 783b455e4f0d06..0d0b1ef0d085df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
@@ -803,8 +804,20 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
- if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+ if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+ // If target shape is of the form 1x1x1x..xn and val is obtained from a
+ // linalg.index op, it will be loop invariant only if index op's dim is not
+ // the trailing dimension.
+ if (llvm::count_if(targetShape,
+ [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+ targetShape.back() != 1) {
+ auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+ return indexOp.getDim() != trailingLoopDim;
+ }
+ // val will be loop variant in some of the other cases.
+ // TODO: Relax this condition
return false;
+ }
auto *ancestor = block->findAncestorOpInBlock(*defOp);
@@ -823,11 +836,17 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
return result;
}
-static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val, int64_t valuePosInExtract) {
+// Determine if the val is obtained from a linalg.index op for the dimension at
+// which it is used to extract a value from the tensor and if it could be used
+// for contiguous memory access.
+static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val,
+ uint64_t valuePosInExtract) {
auto targetShape = linalgOp.getStaticLoopRanges();
assert(targetShape.back() != 1 &&
"1-D vectors with the trailing dim 1 are not yet supported");
+ // val can't be a result of linalg.index for this linalg.generic if it is a
+ // block argument.
auto *block = linalgOp.getBlock();
if (isa<BlockArgument>(val))
return false;
@@ -835,11 +854,17 @@ static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val, int64_t valuePosIn
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is not an operation result");
- auto loopDim = linalgOp.getStaticLoopRanges().size() - valuePosInExtract;
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
- if (indexOp.getDim() == loopDim) {
- return true;
+ // If target shape is of the form 1x1x1x..xn and val is obtained from a
+ // linalg.index op, it will be used for contiguous access only when it is
+ // obtained for the trailing dimension.
+ if (llvm::count_if(targetShape,
+ [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
+ targetShape.back() != 1) {
+ auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+ return indexOp.getDim() == trailingLoopDim;
}
+ return indexOp.getDim() == valuePosInExtract;
}
auto *ancestor = block->findAncestorOpInBlock(*defOp);
@@ -848,7 +873,7 @@ static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val, int64_t valuePosIn
return false;
// Conservatively reject Ops that could lead to indices with stride other
- // than 1.
+ // than 1 after processing the result of linalg.index.
if (!isa<arith::AddIOp, arith::SubIOp, arith::ConstantOp, linalg::IndexOp>(
ancestor))
return false;
@@ -857,7 +882,7 @@ static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val, int64_t valuePosIn
for (auto op : ancestor->getOperands())
result |= isProperLinalgIdx(linalgOp, op, valuePosInExtract);
- return result;
+ return result;
}
/// Check whether \p extractOp would be a gather or a contiguous load Op after
@@ -899,7 +924,9 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
continue;
isLoopInvariantLoad &= isLoopInvariantIdx(linalgOp, indexVal);
- isProperLinalgIdxLoad &= isProperLinalgIdx(linalgOp, indexVal, i);
+ isProperLinalgIdxLoad &= !isLoopInvariantLoad
+ ? isProperLinalgIdx(linalgOp, indexVal, i)
+ : isProperLinalgIdxLoad;
if (!isLoopInvariantLoad && !isProperLinalgIdxLoad) {
LDBG("Found gather load: " << extractOp);
@@ -910,8 +937,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (isLoopInvariantLoad) {
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
- }
- else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
+ } else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}
@@ -984,9 +1010,6 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
// * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
// (0th) element and use that.
SmallVector<Value> transferReadIdxs;
- auto resTrailingDim = resultType.getShape().back();
- auto zero = rewriter.create<arith::ConstantOp>(
- loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
auto idx = bvm.lookup(extractOp.getIndices()[i]);
if (idx.getType().isIndex()) {
@@ -994,11 +1017,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
continue;
}
- auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
- loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
- bvm.lookup(extractOp.getIndices()[i]));
- transferReadIdxs.push_back(
- rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+ auto idxShapedType = dyn_cast<ShapedType>(idx.getType());
+ SmallVector<int64_t> extractIndicesVec(idxShapedType.getRank(), 0);
+
+ transferReadIdxs.push_back(rewriter.create<vector::ExtractOp>(
+ loc, idx, ArrayRef<int64_t>(extractIndicesVec)));
}
// `tensor.extract_element` is always in-bounds, hence the following holds.
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 3fd4fcd536624c..0ac67ca6af6ca7 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -92,17 +92,19 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
-// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
-// CHECK: %[[C0_i32:.*]] = arith.constant 0 : i32
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK: %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK: %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
-// CHECK: %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
-// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
-// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
+// CHECK: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
+// CHECK: %[[CST_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[E0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+// CHECK: %[[E1:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
+// CHECK: %[[E2:.*]] = vector.extract %[[CST_0]][0] : index from vector<3xindex>
+// CHECK: %[[R1:.*]] = vector.transfer_read %[[ARG0]][%[[E0]], %[[E1]], %[[E2]]], %[[CST_1]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[R1]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// CHECK: return %[[RES]] : tensor<1x1x3xf32>
+// CHECK: }
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -113,7 +115,7 @@ module attributes {transform.with_named_sequence} {
}
}
- // -----
+// -----
func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
%c79 = arith.constant 79 : index
@@ -134,26 +136,21 @@ func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16
return %25 : tensor<1x4xf32>
}
-
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_complex(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<45x80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: index, %[[VAL_2:.*]]: index, %[[VAL_3:.*]]: index, %[[VAL_4:.*]]: index,
// CHECK-SAME: %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_10:.*]] = arith.constant 79 : index
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_1]], %[[VAL_2]] : index
-// CHECK: %[[VAL_12:.*]] = vector.broadcast %[[VAL_11]] : index to vector<1x4xindex>
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_3]] : index to vector<4xindex>
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_6]] : vector<4xindex>
// CHECK: %[[VAL_15:.*]] = vector.broadcast %[[VAL_4]] : index to vector<4xindex>
// CHECK: %[[VAL_16:.*]] = arith.addi %[[VAL_14]], %[[VAL_15]] : vector<4xindex>
-// CHECK: %[[VAL_17:.*]] = vector.shape_cast %[[VAL_12]] : vector<1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_18:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_19:.*]] = vector.extractelement %[[VAL_16]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_18]], %[[VAL_10]], %[[VAL_19]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_18:.*]] = vector.extract %[[VAL_16]][0] : index from vector<4xindex>
+// CHECK: %[[VAL_20:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_11]], %[[VAL_10]], %[[VAL_18]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_21:.*]] = vector.transfer_write %[[VAL_20]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_21]] : tensor<1x4xf32>
// CHECK: }
@@ -239,19 +236,21 @@ func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_contiguous_and_gather(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<6xf32>
// CHECK-SAME: %[[VAL_1:.*]]: tensor<5xi32>
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4]> : vector<5xindex>
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]] = arith.constant dense<0> : vector<5xindex>
// CHECK: %[[VAL_5:.*]] = arith.constant dense<5> : vector<5xindex>
// CHECK: %[[VAL_6:.*]] = arith.constant dense<true> : vector<5xi1>
// CHECK: %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<5xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_8:.*]] = tensor.empty() : tensor<5xf32>
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
+// CHECK: %[[E0:.*]] = vector.extract %[[CST]][0] : index from vector<5xindex>
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[E0]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
// CHECK: %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : vector<5xi32> to vector<5xindex>
// CHECK: %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : vector<5xindex>
// CHECK: %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_5]] : vector<5xindex>
-// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_2]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
-// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
+// CHECK: %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[C0]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
+// CHECK: %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
// CHECK: return %[[VAL_14]] : tensor<5xf32>
module attributes {transform.with_named_sequence} {
@@ -286,13 +285,12 @@ func.func @vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<8
// CHECK-SAME: %[[VAL_1:.*]]: index,
// CHECK-SAME: %[[VAL_2:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_7:.*]] = arith.constant 79 : index
// CHECK: %[[VAL_8:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_3]] : vector<4xindex>
-// CHECK: %[[VAL_10:.*]] = vector.extractelement %[[VAL_9]]{{\[}}%[[VAL_4]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_10:.*]] = vector.extract %[[VAL_9]][0] : index from vector<4xindex>
// CHECK: %[[VAL_11:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_10]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_12:.*]] = vector.transfer_write %[[VAL_11]], %[[VAL_2]]{{\[}}%[[VAL_6]], %[[VAL_6]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_12]] : tensor<1x4xf32>
@@ -331,16 +329,31 @@ func.func @vectorize_nd_tensor_extract_with_tensor_extract(%input_1: tensor<1x20
}
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_tensor_extract(
-// CHECK-SAME: %[[INPUT_1:.*]]: tensor<1x20xi32>,
-// CHECK-SAME: %[[INPUT_2:.*]]: tensor<257x24xf32>,
-// CHECK: %[[EXTRACTED_0_IDX_0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXTRACTED_0_IDX_1:.*]] = vector.extractelement %{{.*}}[%{{.*}} : i32] : vector<4xindex>
-// First `tensor.extract` from the generic Op - loop invariant scalar load.
-// CHECK: tensor.extract %[[INPUT_1]][%[[EXTRACTED_0_IDX_0]], %[[EXTRACTED_0_IDX_1]]] : tensor<1x20xi32>
-// The following `tensor.extract` from the generic Op is a contiguous load (all Ops used
-// for address calculation also satisfy the required conditions).
-// CHECK: vector.transfer_read %[[INPUT_2]][%{{.*}}, %{{.*}}, %{{.*}} {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
-
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x20xi32>, %[[ARG1:.*]]: tensor<257x24xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index) -> tensor<1x1x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1x1x4xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<256> : vector<1x1x4xindex>
+// CHECK: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL0:.*]] = tensor.empty() : tensor<1x1x4xf32>
+// CHECK: %[[VAL1:.*]] = arith.addi %[[ARG2]], %[[ARG4]] : index
+// CHECK: %[[VAL2:.*]] = vector.broadcast %[[ARG3]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL3:.*]] = vector.broadcast %[[CST0]] : vector<4xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL4:.*]] = arith.addi %[[VAL2]], %[[VAL3]] : vector<1x1x4xindex>
+// CHECK: %[[VAL5:.*]] = vector.broadcast %[[ARG5]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL4]], %[[VAL5]] : vector<1x1x4xindex>
+// CHECK: %[[EXTRACTED:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[VAL1]]] : tensor<1x20xi32>
+// CHECK: %[[VAL7:.*]] = arith.index_cast %[[EXTRACTED]] : i32 to index
+// CHECK: %[[VAL8:.*]] = vector.broadcast %[[VAL7]] : index to vector<1x1x4xindex>
+// CHECK: %[[VAL9:.*]] = arith.maxsi %[[VAL8]], %[[CST]] : vector<1x1x4xindex>
+// CHECK: %[[VAL10:.*]] = arith.minsi %[[VAL9]], %[[CST1]] : vector<1x1x4xindex>
+// CHECK: %[[VAL11:.*]] = vector.extract %[[VAL10]][0, 0, 0] : index from vector<1x1x4xindex>
+// CHECK: %[[VAL12:.*]] = vector.extract %[[VAL6]][0, 0, 0] : index from vector<1x1x4xindex>
+// CHECK: %[[VAL13:.*]] = vector.transfer_read %[[ARG1]][%[[VAL11]], %[[VAL12]]], %[[CST2]] {in_bounds = [true, true]} : tensor<257x24xf32>, vector<1x4xf32>
+// CHECK: %[[VAL14:.*]] = vector.broadcast %[[VAL13]] : vector<1x4xf32> to vector<1x1x4xf32>
+// CHECK: %[[VAL15:.*]] = vector.transfer_write %[[VAL14]], %[[VAL0]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, tensor<1x1x4xf32>
+// CHECK: return %[[VAL15]] : tensor<1x1x4xf32>
+// CHECK: }
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -461,13 +474,13 @@ func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(%arg0: tensor<80x16
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_with_maxsi_contiguous(
// CHECK-SAME: %[[VAL_0:.*]]: tensor<80x16xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
-// CHECK: %[[VAL_2:.*]] = arith.constant dense<16> : vector<1x4xindex>
-// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant dense<16> : vector<4x1xindex>
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[VAL_6:.*]] = vector.shape_cast %[[VAL_2]] : vector<1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_7:.*]] = vector.extractelement %[[VAL_6]]{{\[}}%[[VAL_3]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_4]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_7:.*]] = vector.extract %[[VAL_2]][0, 0] : index from vector<4x1xindex>
+// CHECK: %[[VAL_6:.*]] = vector.extract %[[VAL_3]][0] : index from vector<4xindex>
+// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_6]]], %[[VAL_5]] {in_bounds = [true, true]} : tensor<80x16xf32>, vector<1x4xf32>
// CHECK: %[[VAL_9:.*]] = vector.transfer_write %[[VAL_8]], %[[VAL_1]]{{\[}}%[[VAL_4]], %[[VAL_4]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
// CHECK: return %[[VAL_9]] : tensor<1x4xf32>
// CHECK: }
@@ -550,3 +563,156 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @vectorize_nd_tensor_extract_contiguous(%arg0: tensor<80x16x17x18x19xf32>, %extracted_slice : tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+ } outs(%extracted_slice : tensor<4x5x6x7x8xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = linalg.index 2 : index
+ %5 = linalg.index 3 : index
+ %6 = linalg.index 4 : index
+ %extracted = tensor.extract %arg0[%2, %3, %4, %5, %6] : tensor<80x16x17x18x19xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<4x5x6x7x8xf32>
+ return %1 : tensor<4x5x6x7x8xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_contiguous(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16x17x18x19xf32>, %[[ARG1:.*]]: tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4]> : vector<5xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5]> : vector<6xindex>
+// CHECK: %[[CST2:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6]> : vector<7xindex>
+// CHECK: %[[CST3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[CST4:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL0:.*]] = vector.extract %[[CST]][0] : index from vector<4xindex>
+// CHECK: %[[VAL1:.*]] = vector.extract %[[CST0]][0] : index from vector<5xindex>
+// CHECK: %[[VAL2:.*]] = vector.extract %[[CST1]][0] : index from vector<6xindex>
+// CHECK: %[[VAL3:.*]] = vector.extract %[[CST2]][0] : index from vector<7xindex>
+// CHECK: %[[VAL4:.*]] = vector.extract %[[CST3]][0] : index from vector<8xindex>
+// CHECK: %[[VAL5:.*]] = vector.transfer_read %[[ARG0]][%[[VAL0]], %[[VAL1]], %[[VAL2]], %[[VAL3]], %[[VAL4]]], %[[CST4]] {in_bounds = [true, true, true, true, true]} : tensor<80x16x17x18x19xf32>, vector<4x5x6x7x8xf32>
+// CHECK: %[[VAL6:.*]] = vector.transfer_write %[[VAL5]], %arg1[%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true]} : vector<4x5x6x7x8xf32>, tensor<4x5x6x7x8xf32>
+// CHECK: return %[[VAL6]] : tensor<4x5x6x7x8xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_nd_tensor_extract_gather(%arg0: tensor<80x16x17x18x19xf32>, %extracted_slice : tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
+ } outs(%extracted_slice : tensor<4x5x6x7x8xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = linalg.index 2 : index
+ %5 = linalg.index 3 : index
+ %extracted = tensor.extract %arg0[%2, %3, %4, %5, %5] : tensor<80x16x17x18x19xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<4x5x6x7x8xf32>
+ return %1 : tensor<4x5x6x7x8xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_gather(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16x17x18x19xf32>, %[[ARG1:.*]]: tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<80x16x17x18x19xf32>, vector<4x5x6x7x8xindex>, vector<4x5x6x7x8xi1>, vector<4x5x6x7x8xf32> into vector<4x5x6x7x8xf32>
+// CHECK: %{{.*}} = vector.transfer_write %[[VAL]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true]} : vector<4x5x6x7x8xf32>, tensor<4x5x6x7x8xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+ // -----
+
+func.func @vectorize_nd_tensor_extract_contiguous_complex(%6: tensor<45x80x16x17xf32>, %arg0: index, %arg1: index, %arg2: index, %arg3: index, %extracted_slice : tensor<1x4x5x6xf32>) -> tensor<1x4x5x6xf32> {
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } outs(%extracted_slice : tensor<1x4x5x6xf32>) {
+ ^bb0(%out: f32):
+ %1 = linalg.index 0 : index
+ %2 = linalg.index 1 : index
+ %3 = linalg.index 2 : index
+ %4 = linalg.index 3 : index
+
+ %21 = arith.addi %arg0, %1 : index
+ %22 = arith.addi %21, %arg1 : index
+
+ %23 = arith.addi %arg0, %2 : index
+ %24 = arith.addi %23, %arg2 : index
+
+ %25 = arith.addi %arg1, %3 : index
+ %26 = arith.addi %arg3, %25 : index
+
+ %27 = arith.addi %arg2, %4 : index
+ %28 = arith.addi %arg3, %27 : index
+
+ %extracted = tensor.extract %6[%22, %24, %26, %28] : tensor<45x80x16x17xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<1x4x5x6xf32>
+ return %0 : tensor<1x4x5x6xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_contiguous_complex(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<45x80x16x17xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: tensor<1x4x5x6xf32>) -> tensor<1x4x5x6xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK: %[[CST0:.*]] = arith.constant dense<[0, 1, 2, 3, 4]> : vector<5xindex>
+// CHECK: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5]> : vector<6xindex>
+// CHECK: %[[CST2:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL0:.*]] = vector.broadcast %[[CST]] : vector<4xindex> to vector<1x6x5x4xindex>
+// CHECK: %[[VAL1:.*]] = vector.transpose %[[VAL0]], [0, 3, 2, 1] : vector<1x6x5x4xindex> to vector<1x4x5x6xindex>
+// CHECK: %[[VAL2:.*]] = vector.broadcast %[[CST0]] : vector<5xindex> to vector<1x4x6x5xindex>
+// CHECK: %[[VAL3:.*]] = vector.transpose %[[VAL2]], [0, 1, 3, 2] : vector<1x4x6x5xindex> to vector<1x4x5x6xindex>
+// CHECK: %[[VAL4:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
+// CHECK: %[[VAL5:.*]] = vector.broadcast %[[ARG1]] : index to vector<1x4x5x6xindex>
+// CHECK: %[[VAL6:.*]] = arith.addi %[[VAL5]], %[[VAL1]] : vector<1x4x5x6xindex>
+// CHECK: %[[VAL7:.*]] = vector.broadcast %[[ARG3]] : index to vector<1x4x5x6xindex>
+// CHECK: %[[VAL8:.*]] = arith.addi %[[VAL6]], %[[VAL7]] : vector<1x4x5x6xindex>
+// CHECK: %[[VAL9:.*]] = vector.broadcast %[[ARG2]] : index to vector<1x4x5x6xindex>
+// CHECK: %[[VAL10:.*]] = arith.addi %[[VAL9]], %[[VAL3]] : vector<1x4x5x6xindex>
+// CHECK: %[[VAL11:.*]] = vector.broadcast %[[ARG4]] : index to vector<1x4x5x6xindex>
+// CHECK: %[[VAL12:.*]] = arith.addi %[[VAL11]], %[[VAL10]] : vector<1x4x5x6xindex>
+// CHECK: %[[VAL13:.*]] = vector.broadcast %[[ARG3]] : index to vector<6xindex>
+// CHECK: %[[VAL14:.*]] = arith.addi %[[VAL13]], %[[CST1]] : vector<6xindex>
+// CHECK: %[[VAL15:.*]] = vector.broadcast %[[ARG4]] : index to vector<6xindex>
+// CHECK: %[[VAL16:.*]] = arith.addi %[[VAL15]], %[[VAL14]] : vector<6xindex>
+// CHECK: %[[VAL17:.*]] = vector.extract %[[VAL8]][0, 0, 0, 0] : index from vector<1x4x5x6xindex>
+// CHECK: %[[VAL18:.*]] = vector.extract %[[VAL12]][0, 0, 0, 0] : index from vector<1x4x5x6xindex>
+// CHECK: %[[VAL19:.*]] = vector.extract %[[VAL16]][0] : index from vector<6xindex>
+// CHECK: %[[VAL20:.*]] = vector.transfer_read %[[ARG0]][%[[VAL4]], %[[VAL17]], %[[VAL18]], %[[VAL19]]], %[[CST2]] {in_bounds = [true, true, true, true]} : tensor<45x80x16x17xf32>, vector<1x4x5x6xf32>
+// CHECK: %[[VAL21:.*]] = vector.transfer_write %[[VAL20]], %[[ARG5]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<1x4x5x6xf32>, tensor<1x4x5x6xf32>
+// CHECK: return %[[VAL21]] : tensor<1x4x5x6xf32>
+// CHECK: }
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
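The index-materialization change in this patch comes down to the following before/after sketch (the vector shapes are illustrative assumptions, and %idx stands for the materialized index vector):

// Before: flatten the n-D index vector to 1-D, then extract element 0
// with a dynamic i32 position.
%c0_i32 = arith.constant 0 : i32
%flat = vector.shape_cast %idx : vector<1x1x4xindex> to vector<4xindex>
%i0_old = vector.extractelement %flat[%c0_i32 : i32] : vector<4xindex>

// After: a single vector.extract with one static zero per rank, which
// generalizes to index vectors of arbitrary rank.
%i0_new = vector.extract %idx[0, 0, 0] : index from vector<1x1x4xindex>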
>From 915ab930a399513ffaccd2062f3128652188491d Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Wed, 27 Dec 2023 12:08:42 +0000
Subject: [PATCH 3/5] Update comments
---
.../Linalg/Transforms/Vectorization.cpp | 37 +++++++++++++++++--
.../Linalg/vectorize-tensor-extract.mlir | 2 +-
2 files changed, 35 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0d0b1ef0d085df..9e231d0e3d394e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -836,9 +836,20 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
return result;
}
-// Determine if the val is obtained from a linalg.index op for the dimension at
-// which it is used to extract a value from the tensor and if it could be used
-// for contiguous memory access.
+// Determine if the \p val is obtained from a linalg.index op for the dimension
+// at which it is used to extract a value from the tensor and if it could be
+// used for contiguous memory access. For example:
+// %val1 = linalg.index 0 : index
+// %e1 = tensor.extract %arg0[%val1, ..] : tensor<3x3xf32> would return true
+// while the following situation
+// %val1 = linalg.index 0 : index
+// %e1 = tensor.extract %arg0[.., %val1] : tensor<3x3xf32> would return false
+// TODO: Relax this requirement to cover contiguous accesses where the inner
+// dimensions of the tensor vary across linalg.generic iterations while the
+// outer dimensions are kept constant.
+// Ex. %c0 = arith.constant 0 : index
+// %val1 = linalg.index 0 : index
+// %e1 = tensor.extract %arg0[%c0, %val1] : tensor<3x3xf32>
static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val,
uint64_t valuePosInExtract) {
auto targetShape = linalgOp.getStaticLoopRanges();
@@ -906,6 +917,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (linalgOp.hasDynamicShape())
return VectorMemoryAccessKind::Gather;
+ // 1. Assume that it's a gather load when reading _into_ a 1-D vector with the
+ // trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+ // TODO: Relax this condition.
+ // FIXME: This condition assumes non-dynamic sizes.
if (targetShape.back() == 1)
return VectorMemoryAccessKind::Gather;
@@ -918,6 +933,9 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
bool isLoopInvariantLoad = true;
bool isProperLinalgIdxLoad = true;
+ // 3. Analyze the indices of `extractOp`.
+ // Look at the way each index is calculated and decide whether it is suitable
+ // for a contiguous load.
auto indices = extractOp.getIndices();
for (auto [i, indexVal] : llvm::enumerate(indices)) {
if (inputShape.getShape()[i] == 1)
@@ -928,6 +946,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
? isProperLinalgIdx(linalgOp, indexVal, i)
: isProperLinalgIdxLoad;
+  // 4a. The load can't be a scalar broadcast or a contiguous load if one of
+  // the indices is neither
+  //    i. loop invariant, nor
+  //    ii. obtained from a linalg.index op whose dimension attribute matches
+  //    the dimension at which this index is used in `extractOp`.
if (!isLoopInvariantLoad && !isProperLinalgIdxLoad) {
LDBG("Found gather load: " << extractOp);
return VectorMemoryAccessKind::Gather;
@@ -935,9 +959,16 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
}
if (isLoopInvariantLoad) {
+ // 4b. It is a scalar broadcast load if all indices of `extractOp` are loop
+ // invariant.
LDBG("Found scalar broadcast load: " << extractOp);
return VectorMemoryAccessKind::ScalarBroadcast;
} else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
+ // 4c. It is a contiguous load if
+ // i. All indices which are not loop invariant and
+ // ii. They are obtained from `linalg.index` ops with their dimension
+ // attributes same as the dimension at which those indices are used in
+ // `extractOp`.
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 0ac67ca6af6ca7..efe61abcea792e 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -645,7 +645,7 @@ module attributes {transform.with_named_sequence} {
}
}
- // -----
+// -----
func.func @vectorize_nd_tensor_extract_contiguous_complex(%6: tensor<45x80x16x17xf32>, %arg0: index, %arg1: index, %arg2: index, %arg3: index, %extracted_slice : tensor<1x4x5x6xf32>) -> tensor<1x4x5x6xf32> {
%0 = linalg.generic {
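For contrast with the contiguous cases above, a minimal sketch of step 4b (a hypothetical example, not a test in this PR): when every tensor.extract index is defined outside the linalg.generic, all indices are loop invariant and the load is classified as a scalar broadcast.

func.func @scalar_broadcast_sketch(%src: tensor<3x3xf32>, %i: index, %j: index,
                                   %init: tensor<1x4xf32>) -> tensor<1x4xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<1x4xf32>) {
  ^bb0(%out: f32):
    // %i and %j are block arguments of the function, not of the generic,
    // so both indices are loop invariant: the same element is read on
    // every iteration and is broadcast into the result vector.
    %e = tensor.extract %src[%i, %j] : tensor<3x3xf32>
    linalg.yield %e : f32
  } -> tensor<1x4xf32>
  return %0 : tensor<1x4xf32>
}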
>From 80dc0631f367bf9dcebcd0f26f1f16ded86d9301 Mon Sep 17 00:00:00 2001
From: meshtag <prathameshtagore at gmail.com>
Date: Wed, 27 Dec 2023 15:26:35 +0000
Subject: [PATCH 4/5] Add more tests
---
.../Linalg/vectorize-tensor-extract.mlir | 52 +++++++++++++++----
1 file changed, 43 insertions(+), 9 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index efe61abcea792e..2515ad4fd9c07d 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -613,11 +613,11 @@ module attributes {transform.with_named_sequence} {
// -----
-func.func @vectorize_nd_tensor_extract_gather(%arg0: tensor<80x16x17x18x19xf32>, %extracted_slice : tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+func.func @vectorize_nd_tensor_extract_gather(%arg0: tensor<80x16x17x18x19xf32>, %extracted_slice : tensor<4x5x6x7xf32>) -> tensor<4x5x6x7xf32> {
%1 = linalg.generic {
- indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>],
- iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
- } outs(%extracted_slice : tensor<4x5x6x7x8xf32>) {
+ indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
+ iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+ } outs(%extracted_slice : tensor<4x5x6x7xf32>) {
^bb0(%out: f32):
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
@@ -625,16 +625,50 @@ func.func @vectorize_nd_tensor_extract_gather(%arg0: tensor<80x16x17x18x19xf32>,
%5 = linalg.index 3 : index
%extracted = tensor.extract %arg0[%2, %3, %4, %5, %5] : tensor<80x16x17x18x19xf32>
linalg.yield %extracted : f32
- } -> tensor<4x5x6x7x8xf32>
- return %1 : tensor<4x5x6x7x8xf32>
+ } -> tensor<4x5x6x7xf32>
+ return %1 : tensor<4x5x6x7xf32>
}
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_gather(
-// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16x17x18x19xf32>, %[[ARG1:.*]]: tensor<4x5x6x7x8xf32>) -> tensor<4x5x6x7x8xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16x17x18x19xf32>, %[[ARG1:.*]]: tensor<4x5x6x7xf32>) -> tensor<4x5x6x7xf32> {
// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<80x16x17x18x19xf32>, vector<4x5x6x7x8xindex>, vector<4x5x6x7x8xi1>, vector<4x5x6x7x8xf32> into vector<4x5x6x7x8xf32>
-// CHECK: %{{.*}} = vector.transfer_write %[[VAL]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true, true]} : vector<4x5x6x7x8xf32>, tensor<4x5x6x7x8xf32>
+// CHECK: %[[VAL:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<80x16x17x18x19xf32>, vector<4x5x6x7xindex>, vector<4x5x6x7xi1>, vector<4x5x6x7xf32> into vector<4x5x6x7xf32>
+// CHECK: %{{.*}} = vector.transfer_write %[[VAL]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true, true]} : vector<4x5x6x7xf32>, tensor<4x5x6x7xf32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 { vectorize_nd_extract } : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @vectorize_nd_tensor_extract_gather_constant_indices(%arg0: tensor<80x16x17x18x19xf32>, %extracted_slice : tensor<6x7x8xf32>) -> tensor<6x7x8xf32> {
+ %c5 = arith.constant 5 : index
+ %1 = linalg.generic {
+ indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } outs(%extracted_slice : tensor<6x7x8xf32>) {
+ ^bb0(%out: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = linalg.index 2 : index
+ %extracted = tensor.extract %arg0[%c5, %c5, %2, %3, %4] : tensor<80x16x17x18x19xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<6x7x8xf32>
+ return %1 : tensor<6x7x8xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_gather_constant_indices(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<80x16x17x18x19xf32>, %[[ARG1:.*]]: tensor<6x7x8xf32>) -> tensor<6x7x8xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5]> : vector<6xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]], %[[C0]], %[[C0]]] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<80x16x17x18x19xf32>, vector<6x7x8xindex>, vector<6x7x8xi1>, vector<6x7x8xf32> into vector<6x7x8xf32>
+// CHECK: %{{.*}} = vector.transfer_write %[[VAL]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<6x7x8xf32>, tensor<6x7x8xf32>
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
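The reworked gather test above still classifies as a gather because one linalg.index result feeds two extract positions. An annotated excerpt of the relevant lines (the comments are added here for exposition and are not in the test):

%5 = linalg.index 3 : index
// Dim 3 matches extract position 3 but not position 4, so the position-based
// check fails for the trailing index and the access-pattern analysis falls
// back to a vector.gather.
%extracted = tensor.extract %arg0[%2, %3, %4, %5, %5] : tensor<80x16x17x18x19xf32>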
>From 8e935e617bc5a05a0cad6d7504351e1f3758cccb Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore <prathameshtagore at gmail.com>
Date: Mon, 1 Jan 2024 13:54:09 +0530
Subject: [PATCH 5/5] Improve comments
---
.../Linalg/Transforms/Vectorization.cpp | 23 ++++++++++++-------
1 file changed, 15 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 9e231d0e3d394e..12ef18f76aba70 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,7 +814,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
return indexOp.getDim() != trailingLoopDim;
}
- // val will be loop variant in some of the other cases.
+ // val will be loop variant in most other cases.
// TODO: Relax this condition
return false;
}
@@ -841,7 +841,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
// used for contiguous memory access. For example:
// %val1 = linalg.index 0 : index
// %e1 = tensor.extract %arg0[%val1, ..] : tensor<3x3xf32> would return true
-// while the following situation
+// while the following case
// %val1 = linalg.index 0 : index
// %e1 = tensor.extract %arg0[.., %val1] : tensor<3x3xf32> would return false
// TODO: Relax this requirement to cover cases for contiguous access where inner
@@ -867,12 +867,18 @@ static bool isProperLinalgIdx(LinalgOp &linalgOp, Value &val,
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
// If target shape is of the form 1x1x1x..xn and val is obtained from a
- // linalg.index op, it will be used for contiguous access only when it is
+ // linalg.index op, it can be used for contiguous access only when it is
// obtained for the trailing dimension.
if (llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) == 1 &&
targetShape.back() != 1) {
auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+
+      // This is special handling for the case where an n-D tensor is
+      // accessed as [p, ..., c, ..., idx_for_trailing_loop_dim], where:
+      //   p = an index satisfying isProperLinalgIdx, and
+      //   c = a loop-invariant index (isLoopInvariantIdx).
return indexOp.getDim() == trailingLoopDim;
}
return indexOp.getDim() == valuePosInExtract;
@@ -918,7 +924,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::Gather;
// 1. Assume that it's a gather load when reading _into_ a 1-D vector with the
- // trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+ // trailing dim equal 1, e.g. `vector<1x4x1xi32>`.
// TODO: Relax this condition.
// FIXME: This condition assumes non-dynamic sizes.
if (targetShape.back() == 1)
@@ -927,6 +933,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
// 2. Assume that it's a gather load when reading _from_ a tensor for which
// the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
// TODO: Relax this condition.
+ // FIXME: This condition assumes non-dynamic sizes.
if (inputShape.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
@@ -965,10 +972,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::ScalarBroadcast;
} else if (!isLoopInvariantLoad && isProperLinalgIdxLoad) {
// 4c. It is a contiguous load if
- // i. All indices which are not loop invariant and
- // ii. They are obtained from `linalg.index` ops with their dimension
- // attributes same as the dimension at which those indices are used in
- // `extractOp`.
+  //    i. some indices are not loop invariant, and
+  //    ii. the loop-variant indices are obtained from `linalg.index` ops
+  //    whose dimension attributes match the dimensions at which those
+  //    indices are used in `extractOp`.
LDBG("Found contigous load: " << extractOp);
return VectorMemoryAccessKind::Contiguous;
}
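Finally, a sketch of the bail-out that step 1 of the updated comments still keeps (a hypothetical example, not a test in this PR): a unit trailing dim in the iteration space makes the analysis report a gather before the indices are inspected at all.

func.func @unit_trailing_dim_sketch(%src: tensor<8x1xf32>,
                                    %init: tensor<8x1xf32>) -> tensor<8x1xf32> {
  %0 = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<8x1xf32>) {
  ^bb0(%out: f32):
    %i = linalg.index 0 : index
    %j = linalg.index 1 : index
    // targetShape.back() == 1 here, so the access-pattern analysis returns
    // VectorMemoryAccessKind::Gather without analyzing %i and %j.
    %e = tensor.extract %src[%i, %j] : tensor<8x1xf32>
    linalg.yield %e : f32
  } -> tensor<8x1xf32>
  return %0 : tensor<8x1xf32>
}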