[Mlir-commits] [mlir] [mlir][linalg][conv] Flatten the channel dimension when vectorizing (PR #71918)
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Nov 13 05:35:56 PST 2023
================
@@ -2936,6 +2938,176 @@ struct Conv1DGenerator
.getOperation();
}
+ /// Generate a vector implementation for ("flatten channel dim"):
+ /// ```
+ /// Op def: ( n, w, c, kw)
+ /// Iters: ({Par(), Par(), Par(), Red()})
+ /// Layout: {{n, 1 * w + dilationW * kw, c}, {kw, c}, {n, w, c}}
+ /// ```
+ /// c of the input/output is collapsed with w. kw is always unrolled and
+ /// broadcast to match w.
+ ///
+ /// TODO: Add support for non-unit stride/dilation
+ /// TODO: w (resp. kw) is unrolled when the strideW ( resp. dilationW) is
+ /// > 1.
+ FailureOr<Operation *> depthwiseConvFlatten() {
+ if (!valid)
+ return rewriter.notifyMatchFailure(op, "unvectorizable depthwise conv");
+
+ int64_t nSize, iSize, wSize, cSize, kwSize;
+ // kernel{kw, c}
+ bindShapeDims(rhsShapedType, kwSize, cSize);
+ // out{n, w, c}
+ bindShapeDims(resShapedType, nSize, wSize);
+ // in{n, w, c}
+ bindShapeDims(lhsShapedType, nSize, iSize);
+
+ vector::TransferWriteOp write;
+ Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+
+ if (strideW == 1)
+ return rewriter.notifyMatchFailure(
+ op, "Non-unit strides are not supported yet");
+ if (dilationW == 1)
+ return rewriter.notifyMatchFailure(
+ op, "Non-unit dilations are not supported yet");
+
+ Type lhsEltType = lhsShapedType.getElementType();
+ Type rhsEltType = rhsShapedType.getElementType();
+ Type resEltType = resShapedType.getElementType();
+ VectorType rhsType = VectorType::get({kwSize, cSize}, rhsEltType);
+ VectorType lhsType = VectorType::get(
+ {nSize,
+ // iw = (ow * sw + kw * dw - 1) * c
+ // (i.e. 16 convolved with 3 (@stride 1 dilation 1) -> 14)
+ (((wSize - 1) * strideW + 1) + ((kwSize - 1) * dilationW + 1) - 1) *
+ cSize},
+ lhsEltType);
+
+ VectorType resType = VectorType::get({nSize, wSize * cSize}, resEltType);
+
+ Value res, lhs, lhsFlat, resFlat;
+ // Read rhs slice of size {kw, c} @ [0, 0].
+ Value rhs = rewriter.create<vector::TransferReadOp>(loc, rhsType, rhsShaped,
+ ValueRange{zero, zero});
+
+ SmallVector<ReassociationIndices> reassociation = {{0}, {1, 2}};
+
+ // Flatten w and c dimensions
+ auto lhsTypeCollapsed = VectorType::get({nSize, iSize * cSize}, lhsEltType);
+ auto linalgOp = dyn_cast<LinalgOp>(op);
+ lhsFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc,
+ RankedTensorType::get(lhsTypeCollapsed.getShape(),
+ lhsEltType),
+ lhsShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(lhsTypeCollapsed.getShape(), lhsEltType),
+ lhsShaped, reassociation);
+ resFlat =
+ linalgOp.hasTensorSemantics()
+ ? (Value)rewriter.create<tensor::CollapseShapeOp>(
+ loc, RankedTensorType::get(resType.getShape(), resEltType),
+ resShaped, reassociation)
+ : (Value)rewriter.create<memref::CollapseShapeOp>(
+ loc, MemRefType::get(resType.getShape(), resEltType),
+ resShaped, reassociation);
+
+ // Read lhs slice of size {n, (w * wSize + kw * dilationW) * c} @ [0,
+ // 0].
+ lhs = rewriter.create<vector::TransferReadOp>(loc, lhsType, lhsFlat,
+ ValueRange{zero, zero});
+ // Read res slice of size {n, w * c} @ [0, 0].
+ res = rewriter.create<vector::TransferReadOp>(loc, resType, resFlat,
+ ValueRange{zero, zero});
+
+ //===------------------------------------------------------------------===//
+ // Begin vector-only rewrite part
+ //===------------------------------------------------------------------===//
+ // Unroll along kw and read slices of lhs and rhs.
+ SmallVector<Value> lhsVals, rhsVals, resVals;
+ // Extract lhs slice of size {n, wSizeStep * c}
+ // @ [0, (sw * w + dw * kw) * cSize].
+ for (int64_t kw = 0; kw < kwSize; ++kw) {
----------------
nicolasvasilache wrote:
The current implementation has too much copy-pasted code for my taste.
Is it possible to implement common helper functions, similar to `extractConvInputSlices`, that would hide away all the complexity and duplication?
https://github.com/llvm/llvm-project/pull/71918
More information about the Mlir-commits mailing list