[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)

Nirvedh Meshram llvmlistbot at llvm.org
Wed Sep 18 09:12:22 PDT 2024


https://github.com/nirvedhmeshram updated https://github.com/llvm/llvm-project/pull/107922

>From 6c6ed797dd191d646826e678068d29a4dce4c72e Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Mon, 9 Sep 2024 21:04:05 +0000
Subject: [PATCH 1/2] [Linalg][Vectorization] Add support for linalg
 vectorization case with outer non unit dim

---
 .../Linalg/Transforms/Vectorization.cpp       | 21 ++++++--
 .../Linalg/vectorize-tensor-extract.mlir      | 52 +++++++++++++++++++
 2 files changed, 69 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..5b709267b63cbd 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,20 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  uint64_t nonUnitDim = 0;
+  uint64_t countNonUnitDim = 0;
+  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+    if (tripCount.value() != 1) {
+      nonUnitDim = tripCount.index();
+      countNonUnitDim++;
+    }
+  }
+  assert(countNonUnitDim == 1 &&
+         "Expected only one non unit loop dim in this linalg op");
+  return nonUnitDim;
+}
+
 /// Checks whether \p val can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
@@ -890,11 +904,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
     return true;
   }
 
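For context on the change above: the old code assumed that only the trailing loop index could be non-constant, which breaks when the non-unit dim is a leading one. Below is a minimal sketch of such an op (shapes and names are hypothetical, not taken from the patch): its static loop ranges are [8, 1], so the only non-unit loop dim is dim 0, which is what getNonUnitLoopDim now returns instead of the hard-coded trailing dim.

func.func @sketch_leading_non_unit_dim(%src: tensor<8xf32>) -> tensor<8x1xf32> {
  %init = tensor.empty() : tensor<8x1xf32>
  %res = linalg.generic {
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
    iterator_types = ["parallel", "parallel"]
  } outs(%init : tensor<8x1xf32>) {
  ^bb0(%out: f32):
    // Static loop ranges are [8, 1]: dim 0 is the only non-unit loop dim, so
    // getNonUnitLoopDim returns 0 (the old code always used the trailing dim, 1).
    %i = linalg.index 0 : index
    %v = tensor.extract %src[%i] : tensor<8xf32>
    linalg.yield %v : f32
  } -> tensor<8x1xf32>
  return %res : tensor<8x1xf32>
}
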
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..e595e529fdcfa2 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>

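The new test above exercises the change end-to-end: the tensor<8x1xf32> output gives static loop ranges of [8, 1], so dim 0 is the single non-unit loop dim, and the extract lowers to the vector.gather checked above. To run just this file locally, something along these lines should work from an MLIR build directory (the exact path is illustrative and depends on the build layout):

  bin/llvm-lit -v tools/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
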
>From a567a545580acac98b1b0ed79973d6e6e35dff7f Mon Sep 17 00:00:00 2001
From: Nirvedh Meshram <nirvedh at gmail.com>
Date: Wed, 18 Sep 2024 16:11:16 +0000
Subject: [PATCH 2/2] add comment

---
 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 5b709267b63cbd..d82b9583637498 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,9 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+/// Find the non-unit dim in a linalgOp. This is used for finding contiguous
+/// loads, and it is expected that only one dim will be non-unit; if that is
+/// not the case this function will assert.
 static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
   uint64_t nonUnitDim = 0;
   uint64_t countNonUnitDim = 0;


