[Mlir-commits] [mlir] [mlir][Linalg] Refine how broadcast dims are treated (PR #99015)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Jul 25 13:12:39 PDT 2024
================
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
loc, value, outputOperand->get(), ValueRange{});
}
- write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+ // The operand map may contain "zero" results, e.g.:
+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+ // When applied to canonical vector shapes like these:
+ // (1, 16, 16, 4)
+ // we would get:
+ // (1, 16, 16, 0)
+ // Instead, we should extract the following map:
+ // (d0, d1, d2, d3) -> (d0, d1, d2)
+ // This way, the corresponding vector/mask type will be:
+ // vector<1x16x16xty>
+ // rather than:
+ // vector<1x16x16x0xty>
+ auto opOperandMapWithoutZeros = opOperandMap.dropZeros();
+ write =
+ state.maskOperation(rewriter, write, linalgOp, opOperandMapWithoutZeros);
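The code comment in this hunk describes how the operand map's zero results affect the mask/vector type. Below is a minimal sketch of that shape arithmetic. It assumes a simplified, hypothetical `SimpleMap` type whose results are either a dimension `d_i` or the constant 0. `SimpleMap` and `vectorType` are illustrative names only. They are not `mlir::AffineMap` or any MLIR API, and the real code operates on `AffineMap` directly.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of an affine map from (d0, ..., dN-1) whose
// results are either a plain dimension d_i or the constant 0.
// Illustration only; this is NOT mlir::AffineMap.
struct SimpleMap {
  unsigned numDims;
  // One entry per result: the dimension position, or std::nullopt for 0.
  std::vector<std::optional<unsigned>> results;

  // Models the idea behind AffineMap::dropZeros(): keep the same input
  // dimensions but remove every constant-0 result.
  SimpleMap dropZeros() const {
    SimpleMap out{numDims, {}};
    for (const auto &r : results)
      if (r.has_value())
        out.results.push_back(r);
    return out;
  }

  // Applies the map to a canonical vector shape: a d_i result selects
  // shape[i], while a constant-0 result yields a size-0 dimension.
  std::vector<int64_t> apply(const std::vector<int64_t> &shape) const {
    assert(shape.size() == numDims && "shape rank must match map inputs");
    std::vector<int64_t> out;
    for (const auto &r : results)
      out.push_back(r.has_value() ? shape[*r] : 0);
    return out;
  }
};

// Renders a shape as an MLIR-style vector type string, e.g. vector<1x16xf32>.
std::string vectorType(const std::vector<int64_t> &shape,
                       const std::string &elemTy) {
  std::string s = "vector<";
  for (int64_t d : shape)
    s += std::to_string(d) + "x";
  return s + elemTy + ">";
}
```

With the map `(d0, d1, d2, d3) -> (d0, d1, d2, 0)` and the canonical shape `(1, 16, 16, 4)`, applying the original map gives `vector<1x16x16x0xf32>`. Applying the map after dropping zeros gives `vector<1x16x16xf32>`, which is the type the comment says the mask should have.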
----------------
banach-space wrote:
> What happens with the xfer read counterpart?
It turns out you've already implemented that :)
https://github.com/llvm/llvm-project/blob/2ba3fe7356f065757a2279f65e4ef5c8f1476293/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp#L1327-L1335
> Should we move this logic into maskOperation?
Yes! Sending an update shortly.
https://github.com/llvm/llvm-project/pull/99015