[Mlir-commits] [mlir] [mlir][vector] Refine vectorisation of tensor.extract (PR #109580)

Andrzej Warzyński llvmlistbot at llvm.org
Fri Oct 18 06:02:19 PDT 2024


================
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether a `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///   * for statically shaped loops, when no masks are used, only one dim is !=
+///   1 (that's what the shape of the output vector is based on).
+///   * for dynamically shaped loops, there might be more non-unit dims
+///   as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
   assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+
+  size_t idx = loopRanges.size() - 1;
+  for (; idx >= 0; idx--)
----------------
banach-space wrote:

Thanks! https://github.com/llvm/llvm-project/pull/112900
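A side note on the last quoted line of the hunk (`for (; idx >= 0; idx--)`): `idx` is declared as `size_t`, which is unsigned. That makes `idx >= 0` always true, so the loop can only end through a `break` in its body (not shown in the quoted hunk). Decrementing past zero wraps around to `SIZE_MAX` instead of reaching -1. GCC reports comparisons like this under `-Wtype-limits`. The two functions below are hypothetical illustrations, not code from this PR or its follow-up. They show the wraparound and a signed-index loop whose condition does become false.

```cpp
#include <cstddef>
#include <cstdint>

// Decrementing an unsigned zero wraps around to the type's maximum value
// (well-defined for unsigned types). A size_t index therefore never becomes
// negative, and `idx >= 0` is always true.
constexpr std::size_t decrementUnsigned(std::size_t idx) { return idx - 1; }

// Counts the iterations of a reverse loop over [0, n) that uses a signed
// index. Here `idx >= 0` becomes false once idx drops to -1.
constexpr int64_t countReverseIterations(int64_t n) {
  int64_t count = 0;
  for (int64_t idx = n - 1; idx >= 0; --idx)
    ++count;
  return count;
}
```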

https://github.com/llvm/llvm-project/pull/109580
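The doc comment in the hunk describes the hook as returning the index of the trailing (innermost) non-unit loop dim. Below is a minimal standalone sketch of that logic. It assumes the loop ranges arrive as a plain `std::vector<int64_t>` in place of `linalgOp.getStaticLoopRanges()`, with a `bool` in place of `linalgOp.hasDynamicShape()`. The names `trailingNonUnitLoopDimIdx` and `kDynamicDim` are hypothetical. `kDynamicDim` only stands in for MLIR's dynamic-size marker, and the sketch does not rely on its exact value. Unlike the `uint64_t` hook in the diff, this sketch returns -1 when no non-unit dim exists.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for MLIR's dynamic-size marker in loop ranges.
constexpr int64_t kDynamicDim = INT64_MIN;

// Returns the index of the trailing (innermost) loop dim whose range is not 1,
// or -1 if every range is 1. `loopRanges` plays the role of
// linalgOp.getStaticLoopRanges(), `hasDynamicShape` that of
// linalgOp.hasDynamicShape().
int64_t trailingNonUnitLoopDimIdx(const std::vector<int64_t> &loopRanges,
                                  bool hasDynamicShape) {
  // Mirrors the invariant from the quoted assert: statically shaped ops are
  // expected to have exactly one non-unit loop dim.
  assert(hasDynamicShape ||
         std::count_if(loopRanges.begin(), loopRanges.end(),
                       [](int64_t dim) { return dim != 1; }) == 1);

  // Scan from the innermost dim outwards. A signed index keeps the loop
  // condition meaningful once idx goes below zero.
  for (int64_t idx = static_cast<int64_t>(loopRanges.size()) - 1; idx >= 0;
       --idx)
    if (loopRanges[static_cast<std::size_t>(idx)] != 1)
      return idx;

  return -1;
}
```

For example, static loop ranges `{1, 1, 8}` give 2 and `{1, 8, 1}` give 1. With a dynamic shape such as `{4, kDynamicDim, 1}`, the dynamic dim counts as non-unit and the result is 1.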

