[Mlir-commits] [mlir] [mlir][linalg] Relax scalable vectorization restrictions (PR #117991)

Diego Caballero llvmlistbot at llvm.org
Thu Nov 28 10:06:02 PST 2024


================
@@ -2022,26 +2022,36 @@ vectorizeScalableVectorPrecondition(Operation *op,
 
   // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
   // it matches one of the supported cases:
-  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
-  //  2. exactly 2 dims are scalable and those are the _last two adjacent_
-  //     parallel dims
-  //  3. exactly 1 reduction dim is scalable and that's the last (innermost) dim
+  //  1. Exactly 1 dim is scalable and that's the _last_ non-unit parallel dim
+  //    (*).
+  //  2. Exactly 2 dims are scalable and those are the _last two adjacent_
+  //     parallel dims.
+  //  3. Exactly 1 reduction dim is scalable and that's the last (innermost)
+  //  dim.
   // The 2nd restriction above means that only Matmul-like Ops are supported
   // when 2 dims are scalable, e.g. :
   //    * iterators = [parallel, parallel, reduction]
   //    * scalable flags = [true, true, false]
+  //
+  // (*) Non-unit dims get folded away in practice.
+  // TODO: Relax these conditions as good motivating examples are identified.
 
-  // Find the first scalable flag
-  bool seenParalell = false;
+  // Find the first scalable flag, and ...
+  bool seenNonUnitParallel = false;
   auto iterators = linalgOp.getIteratorTypesArray();
   SmallVector<bool> scalableFlags(inputScalableVecDims);
-  while (!scalableFlags.back()) {
-    seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+  int64_t idx = scalableFlags.size() - 1;
+  while (!scalableFlags[idx]) {
+    bool isNonUnitDim = (inputVectorSizes[idx] != 1);
+    seenNonUnitParallel |=
+        (iterators[idx] == utils::IteratorType::parallel && isNonUnitDim);
 
     iterators.pop_back();
     scalableFlags.pop_back();
+    idx--;
----------------
dcaballe wrote:

nit: `idx--` --> `--idx` (prefer pre-decrement)
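
For context, the scan in the hunk above can be sketched in isolation as follows. This is only an illustrative stand-alone version, not the code in the PR: `IteratorType`, `TrailingScan`, and `scanTrailingDims` are hypothetical names, plain `std::vector` stands in for the MLIR containers, and the `idx >= 0` guard is an added assumption so the sketch is safe when no dim is scalable. It walks the dims from the innermost one outward until the first scalable flag, records whether a non-unit parallel dim was skipped on the way, and uses the suggested `--idx` form.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for utils::IteratorType.
enum class IteratorType { parallel, reduction };

// Result of the scan (hypothetical helper type).
struct TrailingScan {
  int64_t firstScalableIdx;  // Index of the innermost scalable dim, or -1.
  bool seenNonUnitParallel;  // True if a skipped parallel dim had size != 1.
};

// Walk from the innermost dim outward until a scalable flag is found.
// Unit-sized parallel dims that are skipped do not count as "seen".
TrailingScan scanTrailingDims(const std::vector<IteratorType> &iterators,
                              const std::vector<int64_t> &vectorSizes,
                              const std::vector<bool> &scalableFlags) {
  bool seenNonUnitParallel = false;
  int64_t idx = static_cast<int64_t>(scalableFlags.size()) - 1;
  while (idx >= 0 && !scalableFlags[static_cast<std::size_t>(idx)]) {
    const std::size_t i = static_cast<std::size_t>(idx);
    const bool isNonUnitDim = (vectorSizes[i] != 1);
    seenNonUnitParallel |=
        (iterators[i] == IteratorType::parallel && isNonUnitDim);
    --idx;  // Pre-decrement, as suggested in the nit.
  }
  return {idx, seenNonUnitParallel};
}
```

With vector sizes `[4, 8, 1]` and scalable flags `[false, true, false]` on three parallel dims, the trailing unit dim is skipped without setting `seenNonUnitParallel`, which is the case the relaxed condition 1 is meant to admit.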

https://github.com/llvm/llvm-project/pull/117991

