[Mlir-commits] [mlir] [mlir][linalg] Add scalable vectorisation for depthwise convolutions (PR #81625)
Cullen Rhodes
llvmlistbot at llvm.org
Wed Feb 14 08:27:39 PST 2024
================
@@ -3027,20 +3051,74 @@ struct Conv1DGenerator
// (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1,
cSize},
- lhsEltType);
- VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
- VectorType resType = VectorType::get({nSize, wSize, cSize}, resEltType);
+ lhsEltType, {false, false, scalableChDim});
+ VectorType rhsType =
+ VectorType::get({kwSize, cSize}, rhsEltType,
+ /*scalableDims=*/{false, scalableChDim});
+ VectorType resType =
+ VectorType::get({nSize, wSize, cSize}, resEltType,
+ /*scalableDims=*/{false, false, scalableChDim});
+
+ // Masks the input xfer Op along the channel dim, iff the corresponding
+ // scalable flag is set.
+ auto maybeMaskXferOp = [&](ArrayRef<int64_t> maskShape,
+ ArrayRef<bool> scalableDims,
+ Operation *opToMask) {
+ bool scalableChDim = scalableDims.back();
+ if (!scalableChDim)
+ return opToMask;
+
+ auto maskType =
+ VectorType::get(maskShape, rewriter.getI1Type(), scalableDims);
+
+ SmallVector<OpFoldResult> mixedSourceDims =
+ hasTensorSemantics
+ ? TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return tensor::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ })
+ : TypeSwitch<Operation *, SmallVector<OpFoldResult>>(opToMask)
+ .Case<vector::TransferReadOp>([&](auto readOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ readOp.getSource());
+ })
+ .Case<vector::TransferWriteOp>([&](auto writeOp) {
+ return memref::getMixedSizes(rewriter, loc,
+ writeOp.getOperand(1));
+ });
+
+ Value maskOp =
+ rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
+
+ return mlir::vector::maskOperation(rewriter, opToMask, maskOp);
+ };
// Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0,
// 0].
Value lhs = rewriter.create<vector::TransferReadOp>(
loc, lhsType, lhsShaped, ValueRange{zero, zero, zero});
+ auto maybeMaskedLHS = maybeMaskXferOp(
+ lhsType.getShape(),
+ /*scalableDims=*/{false, false, scalableChDim}, lhs.getDefiningOp());
----------------
c-rhodes wrote:
```suggestion
lhsType.getScalableDims(), lhs.getDefiningOp());
```
https://github.com/llvm/llvm-project/pull/81625
More information about the Mlir-commits mailing list