[PATCH] D78071: [mlir] [VectorOps] Progressive lowering of vector.broadcast
Nicolas Vasilache via Phabricator via llvm-commits
llvm-commits at lists.llvm.org
Thu Apr 16 12:50:21 PDT 2020
nicolasvasilache added a comment.
Sorry for the delay, I had started but got lost in the stretch logic and had to context switch.
So after looking deeper I think the code structure can be reworked to make it less nested and easier to follow (+ comments around each return point).
If you follow the suggestion you should end up with 3 duplicated lines for the `loop (insert)` part (see comments).
Given how much more readable I expect this to end up I'd say it's a good tradeoff.
================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1028
+ auto pos = rewriter.getI64ArrayAttr(d);
+ result = rewriter.create<vector::InsertOp>(loc, dstVectorType, bcst,
+ result, pos);
----------------
nit: we could add another builder in VectorOps to create the `ArrayAttr` for us and just use `d`.
================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1051
+ assert(srcRank == dstRank);
+ for (int64_t r = 0; r < dstRank; r++) {
+ if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
----------------
So this loop runs at most once, but I got confused.
Can we turn it into a `find_if` to get the first dimension `r` that does not match and save that?
Early replace + exit if not found; then the rest will become easier to follow.
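A minimal sketch of the `find_if` idea, outside of MLIR: the shapes are plain `std::vector<int64_t>` instead of `VectorType`, and `firstMismatchedDim` is a hypothetical helper name, not something from the patch. It returns the first dimension whose sizes differ, or -1 when the shapes already match (the "early replace + exit" case).

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper (not part of the patch): returns the first dimension at
// which the two shapes differ, or -1 if every dimension matches.
int64_t firstMismatchedDim(const std::vector<int64_t> &srcShape,
                           const std::vector<int64_t> &dstShape) {
  // Mirrors the `assert(srcRank == dstRank)` in the code under review.
  assert(srcShape.size() == dstShape.size() && "ranks must match");

  // Build the index range [0, rank) so find_if can search over dimensions.
  std::vector<int64_t> dims(srcShape.size());
  std::iota(dims.begin(), dims.end(), 0);

  auto it = std::find_if(dims.begin(), dims.end(), [&](int64_t r) {
    return srcShape[r] != dstShape[r];
  });
  return it == dims.end() ? -1 : *it;
}
```

With the index saved up front, the caller can handle the "no mismatch" case first and then work with a single known `r`, instead of doing the work inside a loop body that runs at most once.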
================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1053
+ if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
+ if (srcRank == 1) {
+ auto at = rewriter.getI64ArrayAttr(0);
----------------
Please add a comment that this is the `scalar (as vector<1x>) to vector case`
================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1055
+ auto at = rewriter.getI64ArrayAttr(0);
+ Value ext =
+ rewriter.create<vector::ExtractOp>(loc, eltType, op.source(), at);
----------------
Same nit re: builder.
================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1068
+ Value bcst;
+ if (r == 0) {
+ // Stretch at start.
----------------
I find the logic intertwining here quite tricky to follow.
Can we make the flow as follows:
```
if (r == 0) {
  extract + broadcast + loop(insert)
  replace and return
}
loop(extract + broadcast + insert)
replace and return
```
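To make the suggested flow concrete, here is a toy model on 2-D integer data rather than MLIR ops. Everything in it is an assumption for illustration: `Row`, `Matrix`, `broadcastRow`, and `stretch` are made-up names, `broadcastRow` stands in for the lower-rank `vector.broadcast`, and returning the result stands in for "replace and return". It only shows the two straight-line paths, not the actual rewrite.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Row = std::vector<int>;
using Matrix = std::vector<Row>;

// Toy stand-in for a lower-rank broadcast: stretch a unit row to n elements,
// or pass an already n-wide row through unchanged.
Row broadcastRow(const Row &src, std::size_t n) {
  if (src.size() == n)
    return src;
  assert(src.size() == 1 && "only a unit dimension can be stretched");
  return Row(n, src[0]);
}

// r is the first dimension whose size differs between source and destination.
Matrix stretch(const Matrix &src, std::size_t dstRows, std::size_t dstCols,
               int r) {
  Matrix result(dstRows);
  if (r == 0) {
    // Stretch at start: extract + broadcast once, then loop(insert).
    assert(src.size() == 1);
    Row bcst = broadcastRow(src[0], dstCols);
    for (std::size_t d = 0; d < dstRows; ++d)
      result[d] = bcst;
    return result; // stands in for "replace and return"
  }
  // Stretch not at start: loop(extract + broadcast + insert).
  assert(src.size() == dstRows);
  for (std::size_t d = 0; d < dstRows; ++d)
    result[d] = broadcastRow(src[d], dstCols);
  return result; // stands in for "replace and return"
}
```

In this model, stretching a 1x3 input to 2x3 takes the `r == 0` path and copies the same row twice, while stretching a 2x1 input to 2x3 takes the second path and broadcasts each row separately. The insert loop is the only part that appears in both paths.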
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D78071/new/
https://reviews.llvm.org/D78071