[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)
Hugo Trachino
llvmlistbot at llvm.org
Thu Jun 20 07:12:03 PDT 2024
================
@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Drops non-scalable unit dims of a broadcastOp that are shared by the
+/// source and the result, using shape_casts: the shape_cast inserted before
+/// the Op drops the unit dims and the one inserted after the Op restores
+/// them. The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+/// vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+/// vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVecTy)
+ return failure();
+ auto resVecTy = broadcastOp.getResultVectorType();
+ auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+ auto resVecTyBuilder = VectorType::Builder(resVecTy);
----------------
nujaa wrote:
I agree. I tried implementing it by appending to new shapes, but since the scalableDims also have to be rebuilt for both the new source and result types, it generated a lot of code that dropDim otherwise hides.
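For comparison, the append-based alternative has to carry the scalable flags alongside every shape entry by hand. The sketch below is a standalone approximation, not the PR's code: it uses plain `std::vector` stand-ins instead of `VectorType`/`VectorType::Builder`, `VecShape` and `dropSharedUnitDims` are hypothetical names, and it assumes the broadcast rule that the source aligns with the trailing dims of the result.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for a VectorType: shape plus per-dim scalable flags.
struct VecShape {
  std::vector<int64_t> dims;
  std::vector<bool> scalable;  // one flag per entry in dims
};

// Hypothetical helper: rebuilds the source and result shapes by appending
// every dim except the non-scalable unit dims shared by both. Broadcast
// aligns the source with the trailing result dims, so source dim i pairs
// with result dim (lead + i). Assumes result rank >= source rank.
std::pair<VecShape, VecShape> dropSharedUnitDims(const VecShape &src,
                                                 const VecShape &res) {
  const std::size_t srcRank = src.dims.size();
  const std::size_t lead = res.dims.size() - srcRank;
  const auto leadOff = static_cast<std::ptrdiff_t>(lead);

  VecShape newSrc;
  // Leading result dims have no source counterpart and are kept as-is (the
  // analogue of getShape().drop_back(srcRank)); their flags are copied too.
  VecShape newRes{
      std::vector<int64_t>(res.dims.begin(), res.dims.begin() + leadOff),
      std::vector<bool>(res.scalable.begin(), res.scalable.begin() + leadOff)};

  for (std::size_t i = 0; i < srcRank; ++i) {
    const std::size_t r = lead + i;
    const bool sharedUnit = src.dims[i] == 1 && res.dims[r] == 1 &&
                            !src.scalable[i] && !res.scalable[r];
    if (sharedUnit)
      continue;  // dropped from both the source and the result
    newSrc.dims.push_back(src.dims[i]);
    newSrc.scalable.push_back(src.scalable[i]);
    newRes.dims.push_back(res.dims[r]);
    newRes.scalable.push_back(res.scalable[r]);
  }
  return {newSrc, newRes};
}
```

On the doc-comment example (a `vector<1x4xf32>` source broadcast to `vector<1x3x1x4xf32>`), this yields shapes `4` and `1x3x4`. Every pair of `push_back` calls is bookkeeping that a `dropDim`-based version does not have to spell out.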
Also, for some reason, I was able to build the base of the new result shape from a slice of the result shape with:
```
SmallVector<int64_t> newResShape =
llvm::to_vector(resVecTy.getShape().drop_back(srcVecTy.getRank()));
```
but for the scalable dims I get errors like the one below, and I don't think I should change the behaviour of SmallVector. I suspect it comes from the way scalableDims are defined.
```
llvm/include/llvm/ADT/SmallVector.h:1317:11: error: type 'decltype(__cont.begin())' (aka 'const bool *') cannot be narrowed to 'bool' in initializer list [-Wc++11-narrowing]
```
To work around it, I had to build the vector with explicit casts, which is ugly:
```
SmallVector<bool> newResScalableDims = {
static_cast<bool>(resVecTy.getScalableDims().begin()),
static_cast<bool>(resVecTy.getScalableDims().drop_back(srcVecTy.getRank()).end())};
```
If you want, I can push my solution on top, and we can revert it if we prefer the current version.
https://github.com/llvm/llvm-project/pull/92938