[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)

Andrzej Warzyński llvmlistbot at llvm.org
Thu Aug 7 13:47:01 PDT 2025


================
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
   llvm_unreachable("unexpected vector.broadcast op error");
 }
 
+// Return the broadcasted dimensions, including broadcasts in the leading
+// dimensions and broadcasts through unit dimensions (i.e. dim-1).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+                                    ArrayRef<int64_t> destShape) {
+  assert(destShape.size() >= srcShape.size());
+  BitVector broadcastedDims(destShape.size());
+  broadcastedDims.set(0, destShape.size() - srcShape.size());
+  auto unitDims = computeBroadcastedUnitDims(srcShape, destShape);
+  for (int64_t dim : unitDims)
+    broadcastedDims.set(dim);
+  return broadcastedDims;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+  auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+  if (!srcShapeCast)
+    return failure();
+
+  VectorType srcType = srcShapeCast.getSourceVectorType();
+  VectorType destType = broadcastOp.getResultVectorType();
+  // Check type compatibility.
+  if (vector::isBroadcastableTo(srcType, destType) !=
+      BroadcastableToResult::Success)
+    return failure();
+
+  // Given
+  // ```
+  // %s = shape_cast(%x)
+  // %b = broadcast(%s)
+  // ```
+  // If we want to fold %x into %b, the broadcasted dimensions from %x to
+  // %b have to be the same as those from %s to %b.
+  ArrayRef<int64_t> shapecastShape =
+      srcShapeCast.getResultVectorType().getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  ArrayRef<int64_t> destShape = destType.getShape();
----------------
banach-space wrote:

[nit] It would be good to make the names in the comments and in the code consistent. Easier said than done!
```suggestion
  // %sc = shape_cast(%src)
  // %bc = broadcast(%sc)
  // ```
  // If we want to fold %src into %bc, the broadcasted dimensions from %src to
  // %bc have to be the same as those from %sc to %bc.
  ArrayRef<int64_t> scShape =
      srcShapeCast.getResultVectorType().getShape();
  ArrayRef<int64_t> srcShape = srcType.getShape();
  ArrayRef<int64_t> bcShape = destType.getShape();
```
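
The condition in the comment (fold `%src` into `%bc` only if the broadcasted dimensions match) can be illustrated with a small, LLVM-free sketch. It assumes static (non-scalable) shapes stored in `std::vector`. `Shape`, `isBroadcastable`, `broadcastedDims` and `canFoldBroadcastOfShapeCast` are hypothetical names chosen for illustration. `isBroadcastable` is only a simplified stand-in for `vector::isBroadcastableTo`, not its actual implementation, while `broadcastedDims` follows the same recipe as `getBroadcastedDims` in the patch (leading dims plus stretched unit dims).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helpers for illustration only; not part of MLIR.
// A static vector shape, e.g. {3, 4} for vector<3x4xf32>.
using Shape = std::vector<std::int64_t>;

// Simplified stand-in for vector::isBroadcastableTo (static shapes only):
// `src` may not have a higher rank than `dest`, and every src dim, aligned
// with the trailing dims of `dest`, must be either 1 or equal to it.
bool isBroadcastable(const Shape &src, const Shape &dest) {
  if (src.size() > dest.size())
    return false;
  std::size_t rankDiff = dest.size() - src.size();
  for (std::size_t i = 0; i < src.size(); ++i) {
    std::int64_t s = src[i];
    std::int64_t d = dest[rankDiff + i];
    if (s != 1 && s != d)
      return false;
  }
  return true;
}

// Mirrors getBroadcastedDims from the patch: marks the leading dims that the
// broadcast adds, plus every dim where a unit src dim is stretched ("dim-1").
std::vector<bool> broadcastedDims(const Shape &src, const Shape &dest) {
  assert(dest.size() >= src.size());
  std::size_t rankDiff = dest.size() - src.size();
  std::vector<bool> dims(dest.size(), false);
  for (std::size_t i = 0; i < rankDiff; ++i)
    dims[i] = true;
  for (std::size_t i = 0; i < src.size(); ++i)
    if (src[i] != dest[rankDiff + i]) // for a valid broadcast, src[i] == 1
      dims[rankDiff + i] = true;
  return dims;
}

// %sc = shape_cast(%src); %bc = broadcast(%sc).
// Rewriting to broadcast(%src) is allowed only if %src broadcasts to %bc's
// shape and both routes broadcast exactly the same result dimensions.
bool canFoldBroadcastOfShapeCast(const Shape &src, const Shape &sc,
                                 const Shape &bc) {
  if (!isBroadcastable(src, bc))
    return false;
  return broadcastedDims(src, bc) == broadcastedDims(sc, bc);
}
```

For example, `canFoldBroadcastOfShapeCast({4}, {1, 4}, {3, 4})` holds, since both routes broadcast only dim 0. In contrast, `canFoldBroadcastOfShapeCast({4, 1}, {1, 4}, {4, 4})` does not: both `vector<4x1>` and `vector<1x4>` broadcast to `vector<4x4>`, but the first stretches dim 1 and the second stretches dim 0, so the two results differ. That is why type compatibility alone is not enough.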

https://github.com/llvm/llvm-project/pull/150523

