[Mlir-commits] [mlir] e45fc51 - [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (#107922)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Sep 21 13:12:55 PDT 2024


Author: Nirvedh Meshram
Date: 2024-09-21T15:12:51-05:00
New Revision: e45fc5140df7bb60242a989ac7fc5cd0c0563234

URL: https://github.com/llvm/llvm-project/commit/e45fc5140df7bb60242a989ac7fc5cd0c0563234
DIFF: https://github.com/llvm/llvm-project/commit/e45fc5140df7bb60242a989ac7fc5cd0c0563234.diff

LOG: [Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (#107922)

In https://github.com/llvm/llvm-project/pull/102321 we relaxed the
vectorizer so that, when checking for contiguous loads, we don't always
have a trailing non-unit dim. For example, in the test case added we have
`tensor<8x1xf32>`, which is now a valid candidate for a contiguous load.
However, the logic to check for contiguous loads assumed that only the
trailing dim would be non-unit, so this PR just updates that logic to find
the actual non-unit dim.

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index a376afa5ddab12..2244f29967dfe4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,6 +810,28 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+/// Returns the index of the non-unit loop dim of `linalgOp` (the last one if
+/// there are several, 0 if there are none). Exactly one non-unit dim is
+/// expected (asserted) for statically shaped Linalg Ops: other cases (e.g.
+/// reading n-D vectors) should've been labelled as gather loads before calling
+/// this hook. This is used for finding contiguous loads (represented as
+/// `tensor.extract`) within `linalg.generic` Ops.
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  uint64_t nonUnitDim = 0;
+  [[maybe_unused]] uint64_t countNonUnitDim = 0; // Only read by the assert.
+  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+    if (tripCount.value() != 1) {
+      nonUnitDim = tripCount.index();
+      countNonUnitDim++;
+    }
+  }
+  // Parenthesize the condition so the message guards the whole check.
+  assert((linalgOp.hasDynamicShape() || countNonUnitDim == 1) &&
+         "For statically shaped Linalg Ops, only one "
+         "non-unit loop dim is expected");
+  return nonUnitDim;
+}
+
 /// Checks whether `val` can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                                VectorType resType) {
@@ -889,11 +911,12 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  // Given the assumption on the loop ranges above, we expect only 1 non-unit
+  // loop dim.
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
     return true;
   }
 

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index bdaa20c3bf971e..ad3a8d9f926082 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map], 
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>


        


More information about the Mlir-commits mailing list