[Mlir-commits] [mlir] [mlir][vector] Refine vectorisation of tensor.extract (PR #109580)
Andrzej Warzyński
llvmlistbot at llvm.org
Fri Oct 18 06:02:19 PDT 2024
================
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
- uint64_t nonUnitDim = 0;
- uint64_t countNonUnitDim = 0;
- for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
- if (tripCount.value() != 1) {
- nonUnitDim = tripCount.index();
- countNonUnitDim++;
- }
- }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+/// * for statically shaped loops, when no masks are used, only one dim is !=
+/// 1 (that's what the shape of the output vector is based on).
+/// * for dynamically shaped loops, there might be more non-unit dims
+/// as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+ SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
assert(linalgOp.hasDynamicShape() ||
- countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
- "non-unit loop dim is expected");
- (void)countNonUnitDim;
- return nonUnitDim;
+ llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+ 1 &&
+ "For statically shaped Linalg Ops, only one "
+ "non-unit loop dim is expected");
+
+ size_t idx = loopRanges.size() - 1;
+ for (; idx >= 0; idx--)
----------------
banach-space wrote:
Thanks! https://github.com/llvm/llvm-project/pull/112900
https://github.com/llvm/llvm-project/pull/109580
More information about the Mlir-commits
mailing list