[llvm-branch-commits] [mlir] b102575 - Revert "[mlir][linalg] Relax tensor.extract vectorization (#99299)"
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Aug 6 14:28:40 PDT 2024
Author: Han-Chung Wang
Date: 2024-08-06T14:28:37-07:00
New Revision: b102575a6cf350124a8967a4e0714718008f72c1
URL: https://github.com/llvm/llvm-project/commit/b102575a6cf350124a8967a4e0714718008f72c1
DIFF: https://github.com/llvm/llvm-project/commit/b102575a6cf350124a8967a4e0714718008f72c1.diff
LOG: Revert "[mlir][linalg] Relax tensor.extract vectorization (#99299)"
This reverts commit 8868c02cda875d1efe1646affa01656ef268ffed.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6da886f5ec19e..3d0d6abf702d7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -946,22 +946,27 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (linalgOp.hasDynamicShape())
return VectorMemoryAccessKind::Gather;
- // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
- // otherwise.
- bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
- return dimSize > 1;
- }) == 1);
- // 1. Assume that it's a gather load when reading non-1D vector.
- if (!isOutput1DVector)
+ // 1. Assume that it's a gather load when reading _into_:
+ // * an n-D "vector", like `tensor<1x2x4xi32>` or `tensor<2x1x4xi32>`, or
+ // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32>`.
+ // TODO: Relax these conditions.
+ // FIXME: This condition assumes non-dynamic sizes.
+ if ((llvm::count_if(targetShape,
+ [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+ targetShape.back() == 1)
+ return VectorMemoryAccessKind::Gather;
+ // 2. Assume that it's a gather load when reading _from_ a tensor for which
+ // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+ // TODO: Relax this condition.
+ if (inputShape.getShape().back() == 1)
return VectorMemoryAccessKind::Gather;
bool leadingIdxsLoopInvariant = true;
- // 2. Analyze the leading indices of `extractOp`.
+ // 3. Analyze the leading indices of `extractOp`.
// Look at the way each index is calculated and decide whether it is suitable
- // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
- // gather load.
+ // for a contiguous load, i.e. whether it's loop invariant.
auto indices = extractOp.getIndices();
auto leadIndices = indices.drop_back(1);
@@ -977,13 +982,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::Gather;
- // 3. Analyze the trailing index for `extractOp`.
+ // 4. Analyze the trailing index for `extractOp`.
// At this point we know that the leading indices are loop invariant. This
// means that it is potentially a scalar or a contiguous load. We can decide
// based on the trailing idx.
auto extractOpTrailingIdx = indices.back();
- // 3a. Scalar broadcast load
+ // 4a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -992,7 +997,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::ScalarBroadcast;
- // 3b. Contiguous loads
+ // 4b. Contiguous loads
// The trailing `extractOp` index should increment with every loop iteration.
// This effectively means that it must be based on the trailing loop index.
// This is what the following bool captures.
@@ -1006,7 +1011,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::Contiguous;
- // 4. Fallback case - gather load.
+ // 5. Fallback case - gather load.
LDBG("Found gather load: " << extractOp);
return VectorMemoryAccessKind::Gather;
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index ac75a19cbeb28..85e1c56dd45a0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -595,59 +595,3 @@ module attributes {transform.with_named_sequence} {
-// -----
-func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
- %c4 = arith.constant 4 : index
- %c0 = arith.constant 0 : index
- %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
- %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
- ^bb0(%out: i32):
- %8 = linalg.index 0 : index
- %idx_0 = linalg.index 0 : index
- %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
- linalg.yield %extracted : i32
- } -> tensor<1x1x4xi32>
- return %out:tensor<1x1x4xi32>
-// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
-// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
-// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
-// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
-// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
-// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
-// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
-// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
-// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
-// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
-// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
-// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
-// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
-// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
-// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
- transform.yield
- }
More information about the llvm-branch-commits
mailing list