[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Andrzej Warzyński
llvmlistbot at llvm.org
Thu Aug 7 13:47:01 PDT 2025
================
@@ -2841,9 +2841,59 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Computes which dimensions of `destShape` are produced by broadcasting from
+// `srcShape`. The returned mask has one bit per dimension of `destShape`. A
+// bit is set for each leading dimension that `srcShape` does not have, and for
+// each unit-dimension (dim-1) broadcast reported by
+// `computeBroadcastedUnitDims`; those reported indices are used directly as
+// positions in `destShape`. Requires rank(destShape) >= rank(srcShape).
+static BitVector getBroadcastedDims(ArrayRef<int64_t> srcShape,
+                                    ArrayRef<int64_t> destShape) {
+  assert(destShape.size() >= srcShape.size());
+  const size_t numLeadingDims = destShape.size() - srcShape.size();
+  BitVector dimsMask(destShape.size());
+  // Leading dimensions that exist only in the destination are broadcasted.
+  dimsMask.set(0, numLeadingDims);
+  // Add the dimensions broadcast through a unit (dim-1) source dimension.
+  for (int64_t stretchedDim : computeBroadcastedUnitDims(srcShape, destShape))
+    dimsMask.set(stretchedDim);
+  return dimsMask;
+}
+
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and the broadcasted dimensions are the same.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ // Given
+ // ```
+ // %s = shape_cast(%x)
+ // %b = broadcast(%s)
+ // ```
+ // If we want to fold %x into %b, the broadcasted dimensions from %x to
+ // %b has to be the same as that of from %s to %b.
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> destShape = destType.getShape();
----------------
banach-space wrote:
[nit] It would be good to make the names in the comments and in the code consistent. Easier said than done!
```suggestion
// %sc = shape_cast(%src)
// %bc = broadcast(%sc)
// ```
// If we want to fold %src into %bc, the broadcasted dimensions from %src to
// %bc have to be the same as those from %sc to %bc.
ArrayRef<int64_t> scShape =
srcShapeCast.getResultVectorType().getShape();
ArrayRef<int64_t> srcShape = srcType.getShape();
ArrayRef<int64_t> bcShape = destType.getShape();
```
https://github.com/llvm/llvm-project/pull/150523
More information about the Mlir-commits
mailing list