[Mlir-commits] [mlir] [mlir][linalg] Enable masked vectorisation for depthwise convolutions (PR #81625)
Diego Caballero
llvmlistbot at llvm.org
Wed Mar 6 15:29:16 PST 2024
================
@@ -3027,20 +3069,68 @@ struct Conv1DGenerator
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
- lhsEltType);
- VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
- VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+ lhsEltType, /*scalableDims=*/{false, false, scalableChDim});
+ VectorType rhsType =
+ VectorType::get({kwSize, cSize}, rhsEltType,
+ /*scalableDims=*/{false, scalableChDim});
+ VectorType resType =
+ VectorType::get({nSize, wSize, cSize}, resEltType,
+ /*scalableDims=*/{false, false, scalableChDim});
+
+ // Masks the input xfer Op along the channel dim, iff the corresponding
+ // scalable flag is set.
+ auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+ ArrayRef<bool> scalableDims,
+ Operation *opToMask) {
+ if (!useMasking)
+ return opToMask;
+ auto maskType =
+ VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+ SmallVector<OpFoldResult> mixedSourceDims =
+ cast<LinalgOp>(op).hasPureTensorSemantics()
+ ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ })
+ : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ });
----------------
dcaballe wrote:
Move to utility? Couldn't this getMixedSizes be implemented on the common ShapedType?
https://github.com/llvm/llvm-project/pull/81625
More information about the Mlir-commits
mailing list