[Mlir-commits] [mlir] [MLIR][Vector]Generalize DropUnitDimFromElementwiseOps (PR #92934)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon May 27 03:00:50 PDT 2024
banach-space wrote:
Sorry for the delay, I was OOO last week. I've finally managed to catch up with the context and I have one high-level comment/question.
The pattern that you are updating was designed to help with specific scenarios that are documented here:
https://github.com/llvm/llvm-project/blob/b0b35964042294d407a995a8407ee5ba93ba5a4b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp#L1610-L1638
However, those cases look very different to what you are trying to "fix":
> discussed [here](https://discourse.llvm.org/t/on-improving-arm-sme-lowering-resilience-in-mlir/78543/17?u=nujaa)
(copied from Ben's "Canonical form")
```mlir
%lhsCast = vector.shape_cast %inputLHS : vector<[4]xf32> to vector<[4]x1xf32>
%lhsBcast = vector.broadcast %lhsCast : vector<[4]x1xf32> to vector<[4]x[4]x1xf32>
%lhsT = vector.transpose %lhsBcast, [1, 0, 2] : vector<[4]x[4]x1xf32> to vector<[4]x[4]x1xf32>
%rhsCast = vector.shape_cast %inputRHS : vector<[4]xf32> to vector<1x[4]xf32>
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
%mul = arith.mulf %lhsT, %rhs : vector<[4]x[4]x1xf32>
%tileMask = vector.create_mask %lhsDim, %rhsDim : vector<[4]x[4]xi1>
%dropDim = vector.shape_cast %mul : vector<[4]x[4]x1xf32> to vector<[4]x[4]xf32>
%addAcc = arith.addf %acc, %dropDim : vector<[4]x[4]xf32>
%applyMask = arith.select %tileMask, %acc, %addAcc : vector<[4]x[4]xi1>, vector<[4]x[4]xf32>
```
In the example above there aren't that many internal unit dims. Here are 2 examples:
```mlir
%rhsBcast = vector.broadcast %rhsCast : vector<1x[4]xf32> to vector<[4]x1x[4]xf32>
%rhs = vector.transpose %rhsBcast, [0, 2, 1] : vector<[4]x1x[4]xf32> to vector<[4]x[4]x1xf32>
```
Would `DropUnitDimFromElementwiseOps` help here at all? If yes, could you write tests for that? From what I can tell, it won't, as neither `vector.broadcast` nor `vector.transpose` is elementwise. But perhaps I missed something?
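For contrast, here is a rough sketch of the kind of rewrite the pattern is meant for, as far as I understand the documentation linked above: an elementwise op whose operands carry a leading (or trailing) unit dim gets wrapped in `vector.shape_cast` ops so that the arithmetic runs on the lower-rank type. The function names `@before` and `@after` and the fixed-size `vector<1x4xf32>` type are made up for illustration, and the exact IR the pattern emits may differ.

```mlir
// Before: an elementwise op on vectors with a leading unit dim.
func.func @before(%a: vector<1x4xf32>, %b: vector<1x4xf32>) -> vector<1x4xf32> {
  %add = arith.addf %a, %b : vector<1x4xf32>
  return %add : vector<1x4xf32>
}

// After (assumed shape of the rewrite): the unit dim is dropped from the
// operands, the elementwise op runs on vector<4xf32>, and the result is
// cast back to the original type.
func.func @after(%a: vector<1x4xf32>, %b: vector<1x4xf32>) -> vector<1x4xf32> {
  %a_sc = vector.shape_cast %a : vector<1x4xf32> to vector<4xf32>
  %b_sc = vector.shape_cast %b : vector<1x4xf32> to vector<4xf32>
  %add_sc = arith.addf %a_sc, %b_sc : vector<4xf32>
  %res = vector.shape_cast %add_sc : vector<4xf32> to vector<1x4xf32>
  return %res : vector<1x4xf32>
}
```

Note that the rewrite is anchored on the elementwise op (`arith.addf` here), which is why it is not obvious that it would apply to a `vector.broadcast` + `vector.transpose` pair.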
https://github.com/llvm/llvm-project/pull/92934