[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)
James Newling
llvmlistbot at llvm.org
Mon Jun 30 10:53:46 PDT 2025
================
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
+/// All BroadcastOps and SplatOps, and ShapeCastOps that only prepends 1s, are
+/// considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() <= srcRank && dstShape.take_back(srcRank) == srcShape;
----------------
newling wrote:
That'd pass for
```
(3,2000) -> (1,1,2000,3)
```
which is not broadcast-like: the trailing dims (2000,3) don't match the source shape (3,2000). The key is that it **only** prepends 1s.
I made this mistake in my first implementation! So it's probably worth adding a negative test case and making the comment clearer (I'm on it).
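A minimal sketch of the "only prepends 1s" condition, written against plain shape vectors rather than MLIR types. `Shape`, `onlyPrependsOnes` and `leadingOnesOnly` are hypothetical names used only for illustration, not MLIR APIs. `onlyPrependsOnes` requires the source shape to be a suffix of the destination shape with nothing but 1s in front of it. `leadingOnesOnly` is a hypothetical weaker check that only inspects the extra leading dims and the element count; it accepts the (3,2000) -> (1,1,2000,3) counterexample above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-alone shape type: dimension sizes, outermost first.
using Shape = std::vector<int64_t>;

// True iff `dst` is `src` with zero or more 1s prepended: rank(dst) >=
// rank(src), every extra leading dim of `dst` is 1, and the trailing
// rank(src) dims of `dst` equal `src` exactly.
bool onlyPrependsOnes(const Shape &src, const Shape &dst) {
  if (dst.size() < src.size())
    return false;
  const std::size_t offset = dst.size() - src.size();
  // Extra leading dims must all be unit dims.
  for (std::size_t i = 0; i < offset; ++i)
    if (dst[i] != 1)
      return false;
  // Trailing dims must reproduce the source shape in order.
  for (std::size_t i = 0; i < src.size(); ++i)
    if (dst[offset + i] != src[i])
      return false;
  return true;
}

// Insufficient check, for contrast: extra leading dims are 1 and the element
// count is preserved, but the trailing dims are never compared to `src`.
bool leadingOnesOnly(const Shape &src, const Shape &dst) {
  if (dst.size() < src.size())
    return false;
  const std::size_t offset = dst.size() - src.size();
  for (std::size_t i = 0; i < offset; ++i)
    if (dst[i] != 1)
      return false;
  int64_t srcCount = 1, dstCount = 1;
  for (int64_t d : src)
    srcCount *= d;
  for (int64_t d : dst)
    dstCount *= d;
  return srcCount == dstCount;
}
```

For a valid shape_cast (equal element counts, all dims positive), matching the trailing dims against the source shape already forces the extra leading dims to be 1. The sketch checks them explicitly only because it accepts arbitrary pairs of shapes.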
https://github.com/llvm/llvm-project/pull/146368