[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)
Benjamin Maxwell
llvmlistbot at llvm.org
Fri Jun 21 05:55:39 PDT 2024
================
@@ -1703,6 +1703,67 @@ struct DropUnitDimFromElementwiseOps final
}
};
+/// Drops unit, non-scalable dimensions of a vector.broadcast that are shared
+/// between the source and the result, by wrapping the op in shape_casts.
+/// The shape_cast inserted before the op drops those unit dims from the
+/// source, and the shape_cast inserted after the op restores them in the
+/// result, so both casts can fold away. The source type must be a vector.
+///
+/// Ex:
+/// ```
+/// %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+/// %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+/// %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+/// %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+/// vector<1x3x1x4xf32>
+/// %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+/// vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+ : public OpRewritePattern<vector::BroadcastOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+ if (!srcVecTy)
+ return failure();
+ auto resVecTy = broadcastOp.getResultVectorType();
+ auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+ auto resVecTyBuilder = VectorType::Builder(resVecTy);
+ auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
+ // Reversing allows us to remove dims from the back without keeping track of
+ // removed dimensions.
+ for (const auto [reversedIndex, dim] :
+ llvm::enumerate(llvm::reverse(srcVecTy.getShape()))) {
+ unsigned srcDimIndex = srcVecTy.getRank() - reversedIndex - 1;
+ unsigned resDimIndex = resVecTy.getRank() - reversedIndex - 1;
+ if (dim == 1 && !srcVecTy.getScalableDims()[srcDimIndex] &&
+ !broadcastedUnitDims.contains(srcDimIndex)) {
+ srcVecTyBuilder.dropDim(srcDimIndex);
+ resVecTyBuilder.dropDim(resDimIndex);
+ }
+ }
+
+ if (VectorType(srcVecTyBuilder) == srcVecTy)
+ return failure();
+ auto loc = broadcastOp->getLoc();
+ auto newSource = rewriter.create<vector::ShapeCastOp>(
+ loc, VectorType(srcVecTyBuilder), broadcastOp.getSource());
----------------
MacDue wrote:
nit: avoid constructing the new vector type twice:
```suggestion
auto newSrcVecTy = VectorType(srcVecTyBuilder);
if (newSrcVecTy == srcVecTy)
return failure();
auto loc = broadcastOp->getLoc();
auto newSource = rewriter.create<vector::ShapeCastOp>(
loc, newSrcVecTy, broadcastOp.getSource());
```
https://github.com/llvm/llvm-project/pull/92938
More information about the Mlir-commits
mailing list