[Mlir-commits] [mlir] [mlir][vector] Canonicalize broadcast of shape_cast (PR #150523)
Andrzej Warzyński
llvmlistbot at llvm.org
Sun Aug 10 09:00:31 PDT 2025
================
@@ -2841,9 +2841,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ // Trailing dimensions should be the same if shape_cast only alters the
+ // leading dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
+ return failure();
+
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
----------------
banach-space wrote:
[nit] Unlike LLVM, MLIR uses `camelCase` for variable names, so `E` -> `e` (rather confusing, I know). If you'd rather avoid `e` (less readable than `E`, IMHO), you could rename `E` -> `dim` 🤷🏻
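For illustration only, here is a standalone sketch of the shape check in `foldBroadcastOfShapeCast` with the nit applied (`E` renamed to `dim`). It is not MLIR code: it assumes `std::vector<int64_t>` in place of `ArrayRef<int64_t>`, assumes all dimension sizes are positive, omits the `isBroadcastableTo` check, and the helper name `onlyLeadingUnitDimsDiffer` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical standalone model of the check in foldBroadcastOfShapeCast.
// srcShape is the shape_cast source shape, castShape its result shape.
// Assumes every dimension size is positive.
static bool onlyLeadingUnitDimsDiffer(const std::vector<int64_t> &srcShape,
                                      const std::vector<int64_t> &castShape) {
  auto numElements = [](const std::vector<int64_t> &shape) {
    return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                           std::multiplies<int64_t>());
  };
  // A well-formed shape_cast preserves the element count.
  assert(numElements(srcShape) == numElements(castShape) &&
         "ill-formed shape_cast");

  // Compare the trailing dimensions that both shapes have.
  const auto numTrailingDims = static_cast<std::ptrdiff_t>(
      std::min(srcShape.size(), castShape.size()));
  if (!std::equal(srcShape.end() - numTrailingDims, srcShape.end(),
                  castShape.end() - numTrailingDims))
    return false;

  // With equal trailing dims and equal element counts, any remaining
  // leading dims must all be 1, so this mirrors the patch's assert.
  auto isUnitDim = [](int64_t dim) { return dim == 1; };
  assert(std::all_of(srcShape.begin(), srcShape.end() - numTrailingDims,
                     isUnitDim) &&
         std::all_of(castShape.begin(), castShape.end() - numTrailingDims,
                     isUnitDim) &&
         "ill-formed shape_cast");
  return true;
}
```

For example, `{8}` vs `{1, 8}` and `{1, 1, 8}` vs `{8}` pass, while `{4, 1}` vs `{4}` is rejected, since that shape_cast drops a trailing rather than a leading unit dim.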
https://github.com/llvm/llvm-project/pull/150523