[PATCH] D78071: [mlir] [VectorOps] Progressive lowering of vector.broadcast

Nicolas Vasilache via Phabricator via llvm-commits llvm-commits at lists.llvm.org
Thu Apr 16 12:50:21 PDT 2020


nicolasvasilache added a comment.

Sorry for the delay; I had started but got lost in the stretch logic and had to context-switch.

After looking deeper, I think the code can be restructured to make it less nested and easier to follow (plus a comment at each return point).
If you follow the suggestion, you should end up with 3 duplicated lines for the `loop (insert)` part (see comments below).
Given how much more readable I expect the result to be, I'd say it's a good tradeoff.



================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1028
+        auto pos = rewriter.getI64ArrayAttr(d);
+        result = rewriter.create<vector::InsertOp>(loc, dstVectorType, bcst,
+                                                   result, pos);
----------------
nit: we could add another builder in VectorOps to create the `ArrayAttr` for us and just use `d`.


================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1051
+    assert(srcRank == dstRank);
+    for (int64_t r = 0; r < dstRank; r++) {
+      if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
----------------
So this loop runs at most once, but it confused me.
Can we turn it into a `find_if` that gets the first dimension that does not match, and save that?
If none is found, replace early and exit; the rest then becomes easier to follow.
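For reference, a minimal standalone sketch of that first-mismatch search. It assumes plain `std::vector<int64_t>` shapes as stand-ins for the vector types' shapes, and `firstMismatchedDim` is a hypothetical name, not something in the patch. It uses `std::mismatch`, which performs the same "first index where the two sequences differ" search as a `find_if` over paired dimensions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper (not part of the patch): returns the index of the first
// dimension whose size differs between `src` and `dst`, or std::nullopt when
// the shapes are identical (the "early replace + exit" case).
std::optional<int64_t> firstMismatchedDim(const std::vector<int64_t> &src,
                                          const std::vector<int64_t> &dst) {
  assert(src.size() == dst.size() && "same-rank case only");
  // std::mismatch stops at the first pair of differing elements.
  auto mm = std::mismatch(src.begin(), src.end(), dst.begin());
  if (mm.first == src.end())
    return std::nullopt;
  return static_cast<int64_t>(mm.first - src.begin());
}
```

In the pattern itself, the returned index would play the role of `r`, and the `nullopt` case maps to the early replace + exit.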


================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1053
+      if (srcVectorType.getDimSize(r) != dstVectorType.getDimSize(r)) {
+        if (srcRank == 1) {
+          auto at = rewriter.getI64ArrayAttr(0);
----------------
Please add a comment noting that this is the `scalar (as vector<1x>) to vector` case.


================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1055
+          auto at = rewriter.getI64ArrayAttr(0);
+          Value ext =
+              rewriter.create<vector::ExtractOp>(loc, eltType, op.source(), at);
----------------
Same nit regarding the builder.


================
Comment at: mlir/lib/Dialect/Vector/VectorTransforms.cpp:1068
+        Value bcst;
+        if (r == 0) {
+          // Stretch at start.
----------------
I find the intertwined logic here quite tricky to follow.
Can we restructure the flow as follows:
```
if (r == 0) {
  extract + broadcast + loop (insert)
  replace and return
}

loop (extract + broadcast + insert)
replace and return
```
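To sanity-check the suggested flow, here is a toy model in plain C++ (not MLIR). `NdVec`, `extract`, `insert` and `broadcastSameRank` are hypothetical stand-ins that mimic `vector.extract`, `vector.insert` and a same-rank `vector.broadcast` on a row-major buffer; the recursive call stands in for emitting a lower-rank `vector.broadcast` that the same pattern lowers again. It is a sketch under those assumptions, showing the early exit, the `r == 0` branch and the loop branch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-in for an n-D vector value: shape plus row-major elements.
// This is NOT MLIR; it only mimics the ops the pattern emits.
struct NdVec {
  std::vector<int64_t> shape;
  std::vector<float> data;
};

static int64_t numElements(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t s : shape)
    n *= s;
  return n;
}

// Mimics vector.extract at position `pos` of the leading dimension.
static NdVec extract(const NdVec &v, int64_t pos) {
  std::vector<int64_t> sub(v.shape.begin() + 1, v.shape.end());
  int64_t stride = numElements(sub);
  return {sub, std::vector<float>(v.data.begin() + pos * stride,
                                  v.data.begin() + (pos + 1) * stride)};
}

// Mimics vector.insert of `sub` at position `pos` of the leading dimension.
static void insert(NdVec &dst, const NdVec &sub, int64_t pos) {
  int64_t stride = numElements(sub.shape);
  std::copy(sub.data.begin(), sub.data.end(), dst.data.begin() + pos * stride);
}

// Same-rank broadcast (stretching size-1 dims), structured as suggested.
static NdVec broadcastSameRank(const NdVec &src,
                               const std::vector<int64_t> &dstShape) {
  assert(src.shape.size() == dstShape.size());
  // find_if-style search for the first mismatching dimension.
  auto mm = std::mismatch(src.shape.begin(), src.shape.end(), dstShape.begin());
  // Shapes already match: early "replace" with the source and exit.
  if (mm.first == src.shape.end())
    return src;
  int64_t r = mm.first - src.shape.begin();
  NdVec result{dstShape, std::vector<float>(numElements(dstShape), 0.0f)};
  // Scalar (as vector<1xT>) to vector case: extract element 0 and splat it.
  if (src.shape.size() == 1) {
    std::fill(result.data.begin(), result.data.end(), src.data[0]);
    return result;
  }
  std::vector<int64_t> dstSub(dstShape.begin() + 1, dstShape.end());
  if (r == 0) {
    // Stretch at start: extract + broadcast once, then loop (insert).
    NdVec bcst = broadcastSameRank(extract(src, 0), dstSub);
    for (int64_t d = 0; d < dstShape[0]; ++d)
      insert(result, bcst, d);
    return result;
  }
  // Stretch further in: loop (extract + broadcast + insert).
  for (int64_t d = 0; d < dstShape[0]; ++d)
    insert(result, broadcastSameRank(extract(src, d), dstSub), d);
  return result;
}
```

The two `for` loops are where the few duplicated `insert` lines end up; in exchange, each branch reads top to bottom and ends at a single, commentable return point.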


Repository:
  rG LLVM Github Monorepo


https://reviews.llvm.org/D78071




