[Mlir-commits] [mlir] b47d178 - [mlir][vector] Refine vectorisation of tensor.extract (#109580)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 24 06:03:35 PDT 2024


Author: Andrzej Warzyński
Date: 2024-09-24T14:03:30+01:00
New Revision: b47d1787b51f55d69ef1b4f88e72cd54af451649

URL: https://github.com/llvm/llvm-project/commit/b47d1787b51f55d69ef1b4f88e72cd54af451649
DIFF: https://github.com/llvm/llvm-project/commit/b47d1787b51f55d69ef1b4f88e72cd54af451649.diff

LOG: [mlir][vector] Refine vectorisation of tensor.extract (#109580)

This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the
following case is vectorised as `vector.gather` (as opposed to
attempting a contiguous load):
```mlir
  func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```

Specifically, when looking for loop-invariant indices in
`tensor.extract` Ops, any `linalg.index` Op that's used in address
calculation should only access loop dims that are == 1. In the example
above, the following does not meet that criterion:
```mlir
  %1 = linalg.index 0 : index
```

Note that this PR also effectively addresses the issue fixed in #107922,
i.e. exercised by:
  * `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in #107922 is still valid though. In
fact, it is required to identify that the following case is a contiguous
load:
```mlir
  func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
    %c0 = arith.constant 0 : index
    %0 = tensor.empty() : tensor<8x1xf32>
    %res = linalg.generic {
      indexing_maps = [#map],
      iterator_types = ["parallel", "parallel"]
    } outs(%0 : tensor<8x1xf32>) {
    ^bb0(%arg1: f32):
        %1 = linalg.index 0 : index
      %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
        linalg.yield %extracted : f32
    } -> tensor<8x1xf32>
    return %res : tensor<8x1xf32>
  }
```
Some logic is still missing to lower the above to
`vector.transfer_read`, so it is conservatively lowered to
`vector.gather` instead (see TODO in
`getTensorExtractMemoryAccessPattern`).

There are a few additional changes:
  * `getNonUnitLoopDim` is simplified and renamed as
    `getTrailingNonUnitLoopDimIdx`, additional comments are added (note
    that the functionality didn't change);
  * extra comments in a few places; variable names in comments updated to
    use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:
  * https://github.com/llvm/llvm-project/pull/107922
  * https://github.com/llvm/llvm-project/pull/102321

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
    mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6800a0fec278c6..c332307da4d333 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///   * for statically shaped loops, when no masks are used, only one dim is !=
+///   1 (that's what the shape of the output vector is based on).
+///   * for dynamically shaped loops, there might be more non-unit dims
+///   as the output vector type is user-specified.
+/// If every loop dim after dim 0 is unit, 0 is returned (dim 0 is the fallback).
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
   assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+  // `idx` is unsigned: stop at 0 rather than testing `idx >= 0` (always true).
+  size_t idx = loopRanges.size() - 1;
+  for (; idx > 0; idx--)
+    if (loopRanges[idx] != 1)
+      break;
+
+  return idx;
 }
 
 /// Checks whether `val` can be used for calculating a loop invariant index.
@@ -854,11 +862,11 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   assert(defOp && "This is neither a block argument nor an operation result");
 
   // IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  // iterations. Note that for dynamic shapes, the corresponding dim will also
+  // be conservatively treated as != 1.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
@@ -877,7 +885,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   return result;
 }
 
-/// Check whether \p val could be used for calculating the trailing index for a
+/// Check whether `val` could be used for calculating the trailing index for a
 /// contiguous load operation.
 ///
 /// There are currently 3 types of values that are allowed here:
@@ -886,13 +894,14 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
 ///   3. results of basic arithmetic operations (linear and continuous)
 ///      involving 1., 2. and 3.
 /// This method returns True if indeed only such values are used in calculating
-/// \p val.
+/// `val`.
 ///
 /// Additionally, the trailing index for a contiguous load operation should
 /// increment by 1 with every loop iteration, i.e. be based on:
 ///   * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
+/// where <dim> is the trailing non-unit dim of the iteration space (this way,
+/// `linalg.index <dim>` increments by 1 with every loop iteration).
+/// `foundIndexOp` is updated to `true` when such Op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                 bool &foundIndexOp, VectorType resType) {
 
@@ -912,12 +921,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, we expect only 1 non-unit
-  // loop dim.
-  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
-
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
+    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
+
+    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
     return true;
   }
 
@@ -1012,7 +1019,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   bool foundIndexOp = false;
   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                               foundIndexOp, resType);
-  isContiguousLoad &= foundIndexOp;
+  // TODO: Support generating contiguous loads for column vectors - that will
+  // require adding a permutation map to transfer_read Ops.
+  bool isRowVector = resType.getShape().back() != 1;
+  isContiguousLoad &= (foundIndexOp && isRowVector);
 
   if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
@@ -1073,6 +1083,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //  b. contiguous loads.
   // Both cases use vector.transfer_read.
 
+  assert(llvm::count_if(resultType.getShape(),
+                        [](uint64_t dim) { return dim != 1; }) &&
+         "Contiguous loads and scalar loads + broadcast only support 1-D "
+         "vectors ATM!");
+
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
   // result type. For indices that are vectors, there are two options:

diff  --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index ad3a8d9f926082..2c56b7139fec49 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -307,6 +307,96 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Reading a 1D column vector (hence a candidate for a contiguous load), but %1
+// (`linalg.index 0`) indexes the leading dim of %src, so it's a gather load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+      %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
+      linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
+// Same as above, but the access indices have been swapped and hence this _is_
+// a contiguous load. Contiguous loads of column vectors are not supported yet
+// (needs a permutation map on vector.transfer_read); lowered as vector.gather.
+// TODO: Make sure that this is lowered as a contiguous load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+      %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
+      linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_contiguous_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
  %c5 = arith.constant 5 : index


        


More information about the Mlir-commits mailing list