[Mlir-commits] [mlir] [mlir][linalg] Restrict scalable vectorisation (PR #98639)
Zhaoshi Zheng
llvmlistbot at llvm.org
Fri Jul 12 09:12:45 PDT 2024
================
@@ -1936,26 +1936,81 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
return success();
}
-/// Preconditions for scalable vectors.
+/// Preconditions for scalable vectors. This is quite restrictive - it models
+/// the fact that in practice we would only make selected dimensions scalable.
static LogicalResult
vectorizeScalableVectorPrecondition(Operation *op,
ArrayRef<int64_t> inputVectorSizes,
ArrayRef<bool> inputScalableVecDims) {
assert(inputVectorSizes.size() == inputScalableVecDims.size() &&
"Number of input vector sizes and scalable dims doesn't match");
- if (inputVectorSizes.empty())
- return success();
+ size_t numOfScalableDims =
+ llvm::count_if(inputScalableVecDims, [](bool flag) { return flag; });
- bool isScalable = inputScalableVecDims.back();
- if (!isScalable)
+ if (numOfScalableDims == 0)
return success();
- // Only element-wise and 1d depthwise conv ops supported in the presence of
- // scalable dims.
auto linalgOp = dyn_cast<LinalgOp>(op);
- return success(linalgOp && (isElementwise(linalgOp) ||
- isa<linalg::DepthwiseConv1DNwcWcOp>(op)));
+
+ // Cond 1: There's been no need for scalable vectorisation of
+ // non-linalg Ops so far
+ if (!linalgOp)
+ return failure();
+
+ // Cond 2: There's been no need for more than 2 scalable dims so far
+ if (numOfScalableDims > 2)
+ return failure();
+
+ // Cond 3: Look at the configuration in `inputScalableVecDims` and verify that
+ // it matches one of the supported cases:
+ // 1. exactly 1 dim is scalable and that's the _last_ parallel dim
+ // 2. exactly 2 dims are scalable and those are the _last two adjacent_
+ // parallel dims
+ // The 2nd restriction above means that only Matmul-like Ops are supported
+ // when 2 dims are scalable, e.g. :
+ // * iterators = [parallel, parallel, reduction]
+ // * scalable flags = [true, true, false]
+
+ // Find the first scalable flag
+ bool seenParalell = false;
+ auto iterators = linalgOp.getIteratorTypesArray();
+ SmallVector<bool> scalableFlags(inputScalableVecDims);
+ if (!scalableFlags.back()) {
----------------
zhaoshiz wrote:
This `if` seems redundant.
https://github.com/llvm/llvm-project/pull/98639
More information about the Mlir-commits
mailing list