[Mlir-commits] [mlir] [mlir][Linalg] Refine how broadcast dims are treated (PR #99015)
Diego Caballero
llvmlistbot at llvm.org
Fri Jul 19 15:33:43 PDT 2024
================
@@ -629,7 +629,21 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value,
loc, value, outputOperand->get(), ValueRange{});
}
- write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+ // The operand map may contain "zero" results, e.g.:
+ // (d0, d1, d2, d3) -> (d0, d1, d2, 0)
+ // When applied to canonical vector shapes like these:
+ // (1, 16, 16, 4)
+ // we would get:
+ // (1, 16, 16, 0)
+ // Instead, we should extract the following map:
+ // (d0, d1, d2, d3) -> (d0, d1, d2)
+ // This way, the corresponding vector/mask type will be:
+ // vector<1x16x16xty>
+ // rather than:
+ // vector<1x16x16x0xty>
+ auto opOperandMapWithoutZeros = opOperandMap.dropZeros();
+ write =
+ state.maskOperation(rewriter, write, linalgOp, opOperandMapWithoutZeros);
----------------
dcaballe wrote:
What happens with the xfer read counterpart? Should we move this logic into `maskOperation`?
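The shape arithmetic described in the diff comment can be modeled without MLIR. The following is a minimal, standalone C++ sketch, not MLIR's API. `MapResult`, `SimpleMap`, `d`, `zeroResult`, `applyToShape`, `vectorType` and `maskShape` are all hypothetical names, and only the `dropZeros` step mirrors the call in the diff. `maskShape` assumes that the zero-dropping has been folded into a single masking helper, which is one possible reading of the review question.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for one result of an affine map: either a loop
// dimension d<i> or the constant 0. This is NOT mlir::AffineExpr.
struct MapResult {
  std::optional<unsigned> dim; // engaged for d<i>, empty for the constant 0
};

// Hypothetical stand-in for mlir::AffineMap: only its list of results.
using SimpleMap = std::vector<MapResult>;

inline MapResult d(unsigned i) { return MapResult{i}; }
inline MapResult zeroResult() { return MapResult{std::nullopt}; }

// Analogue of the opOperandMap.dropZeros() call in the diff: keep only the
// results that refer to a loop dimension.
inline SimpleMap dropZeros(const SimpleMap &map) {
  SimpleMap out;
  for (const MapResult &r : map)
    if (r.dim)
      out.push_back(r);
  return out;
}

// Apply the map to a canonical vector shape: a d<i> result picks that
// extent, while a constant-0 result produces a 0-sized dimension.
inline std::vector<int64_t> applyToShape(const SimpleMap &map,
                                         const std::vector<int64_t> &shape) {
  std::vector<int64_t> out;
  for (const MapResult &r : map) {
    if (!r.dim) {
      out.push_back(0); // constant-0 result -> 0-sized dimension
      continue;
    }
    assert(*r.dim < shape.size() && "map refers to a dim outside the shape");
    out.push_back(shape[*r.dim]);
  }
  return out;
}

// Render a shape the way MLIR prints a vector type, using "ty" as a
// placeholder element type (as the diff comment does).
inline std::string vectorType(const std::vector<int64_t> &shape) {
  std::string s = "vector<";
  for (int64_t extent : shape)
    s += std::to_string(extent) + "x";
  return s + "ty>";
}

// Hypothetical masking-shape helper: because the zero-dropping happens here,
// a read-side caller and a write-side caller would both get it without
// repeating the dropZeros() call themselves.
inline std::vector<int64_t> maskShape(const SimpleMap &operandMap,
                                      const std::vector<int64_t> &canonical) {
  return applyToShape(dropZeros(operandMap), canonical);
}
```

Take the map `(d0, d1, d2, d3) -> (d0, d1, d2, 0)` and the canonical shape `(1, 16, 16, 4)`. In that case `vectorType(applyToShape(map, shape))` gives `vector<1x16x16x0xty>`, while `vectorType(maskShape(map, shape))` gives `vector<1x16x16xty>`. These match the two types in the diff comment. Keeping the drop inside one helper avoids duplicating it at each call site, which appears to be the point of the question about the read counterpart.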
https://github.com/llvm/llvm-project/pull/99015