[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)
Nirvedh Meshram
llvmlistbot at llvm.org
Fri Sep 20 09:03:23 PDT 2024
https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/107922
>From b88f8e389886cc521c9e07b05bb0535661dbd128 Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 9 Sep 2024 21:04:05 +0000
Subject: [PATCH 1/2] [Linalg][Vectorization] Add support for linalg
vectorization case with outer non unit dim
---
.../Linalg/Transforms/Vectorization.cpp | 26 ++++++++--
.../Linalg/vectorize-tensor-extract.mlir | 52 +++++++++++++++++++
2 files changed, 74 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..f8c64d3bf12ff6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,25 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
+
+/// Find the non constant dim in a linalgOp. This is used for finding contiguous
+/// loads and it is expected that only one dim will be non constant, if thats
+/// not the case this function will assert.
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+ uint64_t nonUnitDim = 0;
+ uint64_t countNonUnitDim = 0;
+ for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+ if (tripCount.value() != 1) {
+ nonUnitDim = tripCount.index();
+ countNonUnitDim++;
+ }
+ }
+ assert(countNonUnitDim == 1 &&
+ "Expected only one non unit loop dim in this linalg op");
+ return nonUnitDim;
+}
+
+
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
VectorType resType) {
@@ -889,11 +908,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
- // Given the assumption on the loop ranges above, only the trailing loop
- // index is not constant.
- auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+ auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
- foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+ foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
return true;
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..e595e529fdcfa2 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<8x1xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%0 : tensor<8x1xf32>) {
+ ^bb0(%arg5: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = affine.apply #map1(%arg1, %3, %arg1)
+ %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<8x1xf32>
+ return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
// -----
#map = affine_map<(d0) -> (d0)>
>From 95189eb7c937e6cfd5d71f65e03158bae2d8039d Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Fri, 20 Sep 2024 15:58:15 +0000
Subject: [PATCH 2/2] address reviewer comments
---
.../Linalg/Transforms/Vectorization.cpp | 19 ++++++++++++-------
.../Linalg/vectorize-tensor-extract.mlir | 8 ++++----
2 files changed, 16 insertions(+), 11 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f8c64d3bf12ff6..2244f29967dfe4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,10 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
-
-/// Find the non constant dim in a linalgOp. This is used for finding contiguous
-/// loads and it is expected that only one dim will be non constant, if thats
-/// not the case this function will assert.
+/// Find the non-unit dim in a linalgOp.
+/// When executing this hook, it is expected that only one dim will be non-unit.
+/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
+/// loads before calling this method. This is used for finding contiguous loads
+/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
+/// this condition is expected to hold for statically shaped Linalg Ops only.
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
uint64_t nonUnitDim = 0;
uint64_t countNonUnitDim = 0;
@@ -823,12 +825,13 @@ static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
countNonUnitDim++;
}
}
- assert(countNonUnitDim == 1 &&
- "Expected only one non unit loop dim in this linalg op");
+
+ assert(linalgOp.hasDynamicShape() ||
+ countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
+ "non-unit loop dim is expected");
return nonUnitDim;
}
-
/// Checks whether `val` can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
VectorType resType) {
@@ -908,6 +911,8 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
+ // Given the assumption on the loop ranges above, we expect only 1 non-unit
+ // loop dim.
auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index e595e529fdcfa2..ad3a8d9f926082 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -258,7 +258,7 @@ module attributes {transform.with_named_sequence} {
#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
-func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
%c0 = arith.constant 0 : index
%0 = tensor.empty() : tensor<8x1xf32>
%1 = linalg.generic {
@@ -276,15 +276,15 @@ func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x12
}
module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
- %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
%2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
More information about the Mlir-commits
mailing list