[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Min-Yih Hsu
llvmlistbot at llvm.org
Thu Aug 7 15:35:46 PDT 2025
================
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Return the broadcasted dimensions. Including broadcasts in the leading
+// dimensions and broadcasts through unit dimension (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+ ArrayRef<int64_t> destShape) {
+ assert(destShape.size() >= srcShape.size());
+ BitVector broadcastedDims(destShape.size());
+ broadcastedDims.set(0, destShape.size() - srcShape.size());
+ auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+ for (int64_t dim : unitDims)
+ broadcastedDims.set(dim);
+ return broadcastedDims;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
----------------
mshockwave wrote:
We can roughly break this down into five cases:
(1) srcShape is "broken up" into multiple non-unit dimensions, e.g. <4x1> -> <2x2>
(2) One or more ones are prepended to srcShape
(3) One or more ones are _appended_ to srcShape
(4) One or more leading dimensions in srcShape were removed
(5) One or more _trailing_ dimensions in srcShape were removed
Note that multiple cases can apply at the same time. For instance, <2x1> -> <1x2> removes the trailing unit dimension and then prepends a new one.
Case (1) is easy: srcShape will never be broadcastable to destShape, because the broadcast rule effectively requires the source dimensions to be a "subset" of the destination dimensions, modulo dimensions that are one, and changing the dimension values violates that.
I think cases (2) and (4) are conjugate: broadcasting along prepended unit dimensions is the same as broadcasting into missing (leading) dimensions; similarly, broadcasting into missing leading dimensions is the same as broadcasting along unit dimensions that were once there. Therefore, both are allowed.
Cases (3) and (5) are similar: both change the "neighboring" elements in the highest dimension, so that an element either becomes a singleton or stops being one. For instance, [A, B] turns into [[A], [B]] when we cast from <2> to <2x1>. In that case, element A goes from having a neighbor B to being a singleton. Whether an element is a singleton matters, because an element that is not a singleton is always broadcast together with its neighbor, whereas a singleton can be replicated on its own. Since this alters the broadcasting behavior, once either case appears (even combined with other cases, as in the <2x1> -> <1x2> example above) we cannot do the folding. Note that this also coincides with my current rule: the original replicated dimensions have to match the new replicated dimensions.
The bottom line is: I think your new rule is correct, so I'm going to update to it.
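
To make the case analysis concrete, here is a minimal standalone sketch of the "replicated dimensions must match" check described above. It is not the code in this PR: `Shape`, `isBroadcastable`, `broadcastedDims`, and `canFoldBroadcastOfShapeCast` are hypothetical names, it works on plain `std::vector` shapes rather than MLIR types, and it assumes fixed-size (non-scalable) dimensions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical helper: true if `src` can be broadcast to `dest`, i.e. after
// right-aligning the shapes every source dim equals the dest dim or is 1.
bool isBroadcastable(const Shape &src, const Shape &dest) {
  if (src.size() > dest.size())
    return false;
  size_t rankDiff = dest.size() - src.size();
  for (size_t i = 0; i < src.size(); ++i)
    if (src[i] != dest[rankDiff + i] && src[i] != 1)
      return false;
  return true;
}

// Hypothetical helper: marks every dest dim that is replicated, either
// because it is a missing leading dim or because a unit dim is stretched.
std::vector<bool> broadcastedDims(const Shape &src, const Shape &dest) {
  assert(isBroadcastable(src, dest) && "expected broadcastable shapes");
  size_t rankDiff = dest.size() - src.size();
  std::vector<bool> dims(dest.size(), false);
  for (size_t i = 0; i < rankDiff; ++i)
    dims[i] = true; // leading dims that do not exist in `src`
  for (size_t i = 0; i < src.size(); ++i)
    if (src[i] != dest[rankDiff + i])
      dims[rankDiff + i] = true; // unit dim stretched to a larger size
  return dims;
}

// Sketch of the rule: broadcast(shape_cast(x)) may become broadcast(x) only
// if x itself is broadcastable to `dest` and replicates the same dest dims
// as the shape_cast result does.
bool canFoldBroadcastOfShapeCast(const Shape &x, const Shape &castResult,
                                 const Shape &dest) {
  return isBroadcastable(x, dest) && isBroadcastable(castResult, dest) &&
         broadcastedDims(x, dest) == broadcastedDims(castResult, dest);
}
```

Under this sketch, cases (2) and (4) fold: `<2> -> <1x2>` and `<1x2> -> <2>`, each broadcast to `<3x2>`, replicate only dim 0 both before and after the cast. Case (5) does not: for `<2x1> -> <2>` broadcast to `<2x2>`, the original source replicates dim 1 while the cast result replicates dim 0. Case (1) is rejected earlier, since `<4x1>` is not broadcastable to `<3x2x2>` at all.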
https://github.com/llvm/llvm-project/pull/150523