[Mlir-commits] [mlir] [MLIR][Vector] Add DropUnitDimFromBroadcastOp pattern (PR #92938)

Benjamin Maxwell llvmlistbot at llvm.org
Thu Jun 20 15:03:19 PDT 2024


================
@@ -1703,6 +1703,66 @@ struct DropUnitDimFromElementwiseOps final
   }
 };
 
+/// Drops non-scalable unit dimensions of a broadcastOp that are shared
+/// between the source and the result (i.e. not broadcast), using shape_casts.
+/// The newly inserted shape_cast ops drop the unit dims before the op and
+/// restore them after it. The source type is required to be a vector.
+///
+/// Ex:
+/// ```
+///  %bc = vector.broadcast %arg0 : vector<1x4xf32> to vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %bc : vector<1x3x1x4xf32> to vector<1x3x4xf32>
+/// ```
+///
+/// Gets converted to:
+///
+/// ```
+///  %sc_arg = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///  %bc = vector.broadcast %sc_arg : vector<4xf32> to vector<1x3x4xf32>
+///  %cast_new = vector.shape_cast %bc : vector<1x3x4xf32> to
+///    vector<1x3x1x4xf32>
+///  %cast = vector.shape_cast %cast_new : vector<1x3x1x4xf32> to
+///    vector<1x3x4xf32>
+/// ```
+/// %cast_new and %cast can be folded away.
+struct DropUnitDimFromBroadcastOp final
+    : public OpRewritePattern<vector::BroadcastOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcVecTy = dyn_cast<VectorType>(broadcastOp.getSourceType());
+    if (!srcVecTy)
+      return failure();
+    auto resVecTy = broadcastOp.getResultVectorType();
+    auto srcVecTyBuilder = VectorType::Builder(srcVecTy);
+    auto resVecTyBuilder = VectorType::Builder(resVecTy);
----------------
MacDue wrote:

Btw, I've been trying (for a while now :sweat_smile:) to make writing code like this easier for scalable dims. With my current attempt, #96236, I think you'd be able to rewrite this as:
```c++
auto srcDims = VectorDimList::from(srcVecTy);
auto resDims = VectorDimList::from(resVecTy);
auto rankDiff = resDims.size() - srcDims.size();

SmallVector<VectorDim> newSrcDims;
SmallVector<VectorDim> newResDims(resDims.takeFront(rankDiff));

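auto broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims();
for (auto [idx, dim] : llvm::enumerate(srcDims)) {
  // computeBroadcastedUnitDims() returns result indices, so offset by rankDiff.
  auto resIdx = idx + rankDiff;
  // Drop fixed (non-scalable) unit dims that are not broadcast; keep the rest.
  if (dim == VectorDim::getFixed(1) && !broadcastedUnitDims.contains(resIdx))
    continue;
  newSrcDims.push_back(dim);
  newResDims.push_back(resDims[resIdx]);
}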

auto newSourceType = ScalableVectorType::get(newSrcDims, srcVecTy.getElementType());
auto newResultType = ScalableVectorType::get(newResDims, srcVecTy.getElementType());
```
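To sanity-check the shape arithmetic independently of MLIR, here is a minimal, dependency-free C++ sketch. `Dim`, `Shapes`, `broadcastedUnitDims`, and `dropSharedUnitDims` are hypothetical helpers, not MLIR APIs. They model a dimension as a size plus a scalable flag and assume the index convention of `BroadcastOp::computeBroadcastedUnitDims()`, which reports indices into the result vector.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <set>
#include <vector>

// Hypothetical, MLIR-free model of a vector dimension: a size plus a flag
// saying whether it is scalable (e.g. [4] in vector<[4]xf32>).
struct Dim {
  int64_t size;
  bool scalable;
  bool operator==(const Dim &o) const {
    return size == o.size && scalable == o.scalable;
  }
};

struct Shapes {
  std::vector<Dim> src, res;
};

// Sketch of the idea behind computeBroadcastedUnitDims(): returns the
// *result* indices where a source dim differs from the aligned result dim,
// i.e. where a unit dim is stretched by the broadcast.
std::set<size_t> broadcastedUnitDims(const std::vector<Dim> &src,
                                     const std::vector<Dim> &res) {
  size_t rankDiff = res.size() - src.size();
  std::set<size_t> out;
  for (size_t i = 0; i < src.size(); ++i)
    if (!(src[i] == res[i + rankDiff]))
      out.insert(i + rankDiff);
  return out;
}

// Drops fixed (non-scalable) unit dims that the broadcast does not stretch,
// from both the source shape and the aligned result shape.
Shapes dropSharedUnitDims(const std::vector<Dim> &src,
                          const std::vector<Dim> &res) {
  assert(res.size() >= src.size() && "broadcast cannot reduce rank");
  size_t rankDiff = res.size() - src.size();
  std::set<size_t> bcast = broadcastedUnitDims(src, res);
  Shapes out;
  // Leading result dims are created by the broadcast itself; keep them.
  out.res.assign(res.begin(), res.begin() + rankDiff);
  for (size_t i = 0; i < src.size(); ++i) {
    size_t resIdx = i + rankDiff;
    bool fixedUnit = src[i].size == 1 && !src[i].scalable;
    if (fixedUnit && bcast.count(resIdx) == 0)
      continue; // Shared unit dim: drop it from both shapes.
    out.src.push_back(src[i]);
    out.res.push_back(res[resIdx]);
  }
  return out;
}
```

For the doc-comment example (`vector<1x4xf32>` to `vector<1x3x1x4xf32>`) this yields a source shape of `[4]` and a result shape of `[1, 3, 4]`, matching the rewritten IR. A unit dim that is actually stretched (e.g. `1` to `8`) or a scalable unit dim (`[1]`) is kept.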

https://github.com/llvm/llvm-project/pull/92938
