[Mlir-commits] [mlir] [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (PR #107922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Sep 9 14:53:15 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Nirvedh Meshram (nirvedhmeshram)

<details>
<summary>Changes</summary>

There is a case, shown in https://github.com/llvm/llvm-project/issues/107476, that the current vectorization patterns can't handle. This PR handles it by inserting an extra transpose op, which should then cancel out against the existing transpose.
Fixes: https://github.com/llvm/llvm-project/issues/107476

---
Full diff: https://github.com/llvm/llvm-project/pull/107922.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+25) 
- (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir (+48) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 63dcda78d0f2be..16d1b1d6e0d0d4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1079,6 +1079,31 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       continue;
     }
 
+    auto idxType = dyn_cast<VectorType>(idx.getType());
+
+    if (idxType && idxType.getShape().size() == resultType.getShape().size()) {
+      auto maxElement = std::max_element(resultType.getShape().begin(),
+                                         resultType.getShape().end());
+      auto maxElementDim =
+          std::distance(resultType.getShape().begin(), maxElement);
+      // This means that the result type of the index is non trailing and we
+      // insert transpose op in this case to match it to the extract type.
+      if (maxElementDim != resultType.getShape().size() - 1) {
+        SmallVector<int64_t> transposition = llvm::to_vector<16>(
+            llvm::seq<int64_t>(0, resultType.getShape().size()));
+        std::swap(transposition.back(), transposition[maxElementDim]);
+        auto transposeOp =
+            rewriter.create<vector::TransposeOp>(loc, idx, transposition);
+        auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+            loc,
+            VectorType::get(*maxElement, rewriter.getIndexType(),
+                            resultType.getScalableDims().back()),
+            transposeOp);
+        transferReadIdxs.push_back(rewriter.create<vector::ExtractElementOp>(
+            loc, indexAs1dVector, zero));
+        continue;
+      }
+    }
     auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
         loc,
         VectorType::get(resultType.getShape().back(), rewriter.getIndexType(),
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..b66a0c4e4093b0 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,54 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map], 
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_without_outer_unit_dim
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK-DAG: %[[C0_i32:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:     %[[IDX0:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:     %[[IDX1:.*]] = vector.broadcast %[[CST_0]] : vector<8xindex> to vector<1x8xindex
+// CHECK:    %[[IDX2:.*]] = vector.transpose %[[IDX1]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:    %[[IDX3:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK:    %[[IDX4:.*]] = vector.transpose %[[IDX2]], [1, 0] : vector<8x1xindex> to vector<1x8xindex>
+// CHECK:    %[[IDX5:.*]] = vector.shape_cast %[[IDX4]] : vector<1x8xindex> to vector<8xindex>
+// CHECK:    %[[IDX6:.*]] = vector.extractelement %[[IDX5]][%[[C0_i32]] : i32] : vector<8xindex>
+// CHECK:    %[[IDX7:.*]] = vector.transfer_read %[[ARG0]][%[[IDX6]], %[[C0]], %[[IDX3]]], %[[CST]] {in_bounds = [true, true]} : tensor<8x128x768xf32>, vector<8x1xf32>
+// CHECK:    vector.transfer_write %[[IDX7]], %[[IDX0]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>

``````````

</details>


https://github.com/llvm/llvm-project/pull/107922


More information about the Mlir-commits mailing list