[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

James Newling llvmlistbot at llvm.org
Wed Aug 6 09:32:43 PDT 2025


================
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
+// Return the broadcasted dimensions, including broadcasts in the leading
+// dimensions and broadcasts through unit dimensions (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+                                    ArrayRef<int64_t> destShape) {
+  assert(destShape.size() >= srcShape.size());
+  BitVector broadcastedDims(destShape.size());
+  broadcastedDims.set(0, destShape.size() - srcShape.size());
+  auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+  for (int64_t dim : unitDims)
+    broadcastedDims.set(dim);
+  return broadcastedDims;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
----------------
newling wrote:

I _think_ this is the same as saying (where srcShape -> shapeCastShape -> destShape):

1) rank(srcShape) <= rank(destShape)
2) srcShape and shapeCastShape are the same, except that one has some 1's prepended. i.e. where R = min(srcShape.rank, shapeCastShape.rank), the last R dimensions of srcShape and shapeCastShape are the same.

If so, that would be more intuitive, I think. If not, can you please provide a counterexample?

https://github.com/llvm/llvm-project/pull/150523

