[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)
Nirvedh Meshram
llvmlistbot at llvm.org
Wed Sep 18 09:12:22 PDT 2024
https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/107922
>From 6c6ed797dd191d646826e678068d29a4dce4c72e Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 9 Sep 2024 21:04:05 +0000
Subject: [PATCH 1/2] [Linalg][Vectorization] Add support for linalg
vectorization case with outer non unit dim
---
.../Linalg/Transforms/Vectorization.cpp | 21 ++++++--
.../Linalg/vectorize-tensor-extract.mlir | 52 +++++++++++++++++++
2 files changed, 69 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..5b709267b63cbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,20 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  // Index of the single loop dim whose static trip count is not 1.
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
+  // Counted in the assert so NDEBUG builds get no set-but-unused counter.
+  assert(llvm::count_if(loopRanges, [](int64_t d) { return d != 1; }) == 1 &&
+         "Expected only one non unit loop dim in this linalg op");
+  uint64_t nonUnitDim = 0; // Without asserts: last non-unit dim, else 0.
+  for (auto [idx, tripCount] : llvm::enumerate(loopRanges)) {
+    if (tripCount != 1)
+      nonUnitDim = idx;
+  }
+  return nonUnitDim;
+}
+
/// Checks whether /p val can be used for calculating a loop invariant index.
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
@@ -890,11 +904,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
Operation *defOp = val.getDefiningOp();
assert(defOp && "This is neither a block argument nor an operation result");
- // Given the assumption on the loop ranges above, only the trailing loop
- // index is not constant.
- auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+ auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
- foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+ foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
return true;
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..e595e529fdcfa2 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// The loops iterate over tensor<8x1xf32>, so only the outer dim (d0, trip
+// count 8) is non-unit. The trailing tensor.extract index is computed from
+// `linalg.index 1` (the unit dim), and the CHECK lines below expect the
+// extract to be vectorized as a vector.gather.
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<8x1xf32>
+ %1 = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel"]
+ } outs(%0 : tensor<8x1xf32>) {
+ ^bb0(%arg5: f32):
+ %2 = linalg.index 0 : index
+ %3 = linalg.index 1 : index
+ %4 = affine.apply #map1(%arg1, %3, %arg1)
+ %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+ linalg.yield %extracted : f32
+ } -> tensor<8x1xf32>
+ return %1 : tensor<8x1xf32>
+}
+
+// Transform script: match the linalg.generic, take its isolated-from-above
+// parent, and vectorize all of its children with `vectorize_nd_extract` set.
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
// -----
#map = affine_map<(d0) -> (d0)>
>From a567a545580acac98b1b0ed79973d6e6e35dff7f Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Wed, 18 Sep 2024 16:11:16 +0000
Subject: [PATCH 2/2] add comment
---
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 3 +++
1 file changed, 3 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5b709267b63cbd..d82b9583637498 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,9 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
+/// Find the non-unit loop dim in a linalgOp. This is used for finding
+/// contiguous loads, and exactly one loop dim is expected to be non-unit; if
+/// that's not the case, this function will assert.
static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
uint64_t nonUnitDim = 0;
uint64_t countNonUnitDim = 0;
More information about the Mlir-commits
mailing list