[Mlir-commits] [mlir] [mlir][linalg] Restrict scalable vectorisation (PR #98639)

Zhaoshi Zheng llvmlistbot at llvm.org
Fri Jul 12 09:12:50 PDT 2024


================
@@ -1936,26 +1936,81 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
-/// Preconditions for scalable vectors.
+/// Preconditions for scalable vectors. This is quite restrictive - it models
+/// the fact that in practice we would only make selected dimensions scalable.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
                                     ArrayRef<int64_t> inputVectorSizes,
                                     ArrayRef<bool> inputScalableVecDims) {
   assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
          "Number of input vector sizes and scalable dims doesn't match");
 
-  if (inputVectorSizes.empty())
-    return success();
+  size_t numOfScalableDims =
+      llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
 
-  bool isScalable = inputScalableVecDims.back();
-  if (!isScalable)
+  if (numOfScalableDims == 0)
     return success();
 
-  // Only element-wise and 1d depthwise conv ops supported in the presence of
-  // scalable dims.
   auto linalgOp = dyn_cast<LinalgOp>(op);
-  return success(linalgOp && (isElementwise(linalgOp) ||
-                              isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
+
+  // Cond 1: There's been no need for scalable vectorisation of
+  // non-linalg Ops so far
+  if (!linalgOp)
+    return failure();
+
+  // Cond 2: There's been no need for more than 2 scalable dims so far
+  if (numOfScalableDims > 2)
+    return failure();
+
+  // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
+  // it matches one of the supported cases:
+  //  1. exactly 1 dim is scalable and that's the _last_ parallel dim
+  //  2. exactly 2 dims are scalable and those are the _last two adjacent_
+  //     parallel dims
+  // The 2nd restriction above means that only Matmul-like Ops are supported
+  // when 2 dims are scalable, e.g. :
+  //    * iterators = [parallel, parallel, reduction]
+  //    * scalable flags = [true, true, false]
+
+  // Find the first scalable flag
+  bool seenParalell = false;
+  auto iterators = linalgOp.getIteratorTypesArray();
+  SmallVector<bool> scalableFlags(inputScalableVecDims);
+  if (!scalableFlags.back()) {
+    while (!scalableFlags.back()) {
+      seenParalell |= (iterators.back() == utils::IteratorType::parallel);
+
+      iterators.pop_back();
+      scalableFlags.pop_back();
+    }
+  }
+
+  // TODO: Support scalable vectorisation for reduction dims
+  if (iterators.back() == utils::IteratorType::reduction)
+    return failure();
+
+  // If this is not the _last_ parallel dim, 1. above is not met
+  if (seenParalell)
+    return failure();
+
+  // If present, check the 2nd scalable dim. ATM, only Matmul-like Ops are
+  // supported, for which we expect the following config:
+  //    * iterators = [parallel, parallel, reduction]
+  //    * scalable flags = [true, true, false]
+  if (numOfScalableDims == 2) {
+    scalableFlags.pop_back();
+    iterators.pop_back();
+
+    if (!scalableFlags.back() ||
+        (iterators.back() != utils::IteratorType::parallel))
+      return failure();
+  }
+
+  // Cond 4: Only element-wise and 1d depthwise conv ops supported in the
----------------
zhaoshiz wrote:

Nit: since matmul-like ops are now supported too, this comment should read: "element-wise, matmul-like, and 1d depthwise conv ops".

https://github.com/llvm/llvm-project/pull/98639


More information about the Mlir-commits mailing list