[Mlir-commits] [mlir] [mlir][linalg] Relax scalable vectorization restrictions (PR #117991)
Diego Caballero
llvmlistbot at llvm.org
Thu Nov 28 10:06:02 PST 2024
================
@@ -2022,26 +2022,36 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
// it matches one of the supported cases:
- // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
- // 2. exactly 2 dims are scalable and those are the _last two adjacent_
- // parallel dims
- // 3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+ // 1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+ // (*).
+ // 2. Exactly 2 dims are scalable and those are the _last two adjacent_
+ // parallel dims.
+ // 3. Exactly 1 reduction dim is scalable and that's the last (innermost)
+ // dim.
// The 2nd restriction above means that only Matmul-like Ops are supported
// when 2 dims are scalable, e.g. :
// * iterators = [parallel, parallel, reduction]
// * scalable flags = [true, true, false]
+ //
+ // (*) Non-unit dims get folded away in practice.
+ // TODO: Relax these conditions as good motivating examples are identified.
- // Find the first scalable flag
- bool seenParalell = false;
+ // Find the first scalable flag, and ...
+ bool seenNonUnitParallel = false;
auto iterators = linalgOp.getIteratorTypesArray();
SmallVector<bool> scalableFlags(inputScalableVecDims);
- while (!scalableFlags.back()) {
- seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+ int64_t idx = scalableFlags.size() - 1;
+ while (!scalableFlags[idx]) {
+ bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+ seenNonUnitParallel |=
+ (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
iterators.pop_back();
scalableFlags.pop_back();
+ idx--;
----------------
dcaballe wrote:
nit: `idx--` --> `--idx`
https://github.com/llvm/llvm-project/pull/117991
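
For readers without the surrounding file, here is a minimal standalone sketch of the trailing-dimension scan that the hunk above introduces. It is not the MLIR implementation: `IteratorKind`, `TrailingScan` and `scanTrailingDims` are hypothetical names, plain `std::vector` stands in for the MLIR containers, and the `idx >= 0` bound is an added assumption for the case where no flag is set (the quoted loop appears to rely on an earlier check for that). It uses the pre-decrement form suggested in the nit.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for utils::IteratorType.
enum class IteratorKind { Parallel, Reduction };

// Hypothetical result type: where the innermost scalable dim sits, and
// whether a non-unit parallel dim was seen to its right.
struct TrailingScan {
  int64_t firstScalableIdx;  // -1 if no dim is flagged scalable
  bool seenNonUnitParallel;
};

// Walk from the innermost dim outwards until the first scalable flag,
// recording whether any skipped dim is parallel with size != 1.
TrailingScan scanTrailingDims(const std::vector<IteratorKind> &iterators,
                              const std::vector<int64_t> &vectorSizes,
                              const std::vector<bool> &scalableFlags) {
  assert(iterators.size() == vectorSizes.size() &&
         iterators.size() == scalableFlags.size());
  TrailingScan result{-1, false};
  for (int64_t idx = static_cast<int64_t>(scalableFlags.size()) - 1;
       idx >= 0; --idx) {
    if (scalableFlags[idx]) {
      result.firstScalableIdx = idx;
      break;
    }
    bool isNonUnitDim = (vectorSizes[idx] != 1);
    result.seenNonUnitParallel |=
        (iterators[idx] == IteratorKind::Parallel && isNonUnitDim);
  }
  return result;
}
```

Under these assumptions, a trailing parallel dim of size 1 does not set `seenNonUnitParallel`, which is the relaxation described in the updated comment for case 1.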