[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)
James Newling
llvmlistbot at llvm.org
Mon Jun 30 10:46:38 PDT 2025
================
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
+/// All BroadcastOps and SplatOps, and ShapeCastOps that only prepend 1s, are
+/// considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // Check that the shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3).
+  // Checking that the destination shape merely starts with 1s is not enough:
+  // (2,3) -> (1,3,2) is not broadcastlike. Instead require the source shape to
+  // be a suffix of the destination shape. Since shape_cast preserves the
+  // element count, the remaining leading dims must then all be 1.
+  // NOTE: the destination rank must be >= the source rank. With `<=`,
+  // `take_back` would return the whole destination shape and only identity
+  // shape_casts would match.
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
- Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+ Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+ if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
return Value();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
- return source;
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
+ Value src = broadcastLikeOp->getOperand(0);
+
+ // Replace extract(broadcast(X)) with X
+ if (extractOp.getType() == src.getType())
+ return src;
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
- return source;
+ // Get required types and ranks in the chain
+ // src -> broadcastDst -> dst
+ auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+ auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+ unsigned srcRank = srcType ? srcType.getRank() : 0;
+ unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+ unsigned dstRank = dstType ? dstType.getRank() : 0;
- unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank > broadcastSrcRank)
+ // Cannot do without the broadcast if overall the rank increases.
+ if (dstRank > srcRank)
return Value();
- // Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
- if (extractVecType && broadcastVecType &&
- extractVecType.getShape() !=
- broadcastVecType.getShape().take_back(extractResultRank))
+
+ assert(srcType && "src must be a vector type because of previous checks");
----------------
newling wrote:
It's the intervening rank check which allows this assertion. Code is like
```c++
if (extractOp.getType() == src.getType())
return src;
[...]
unsigned srcRank = srcType ? srcType.getRank() : 0;
[...]
if (dstRank > srcRank)
return Value();
[...]
assert(srcType && "src must be a vector type because of previous checks");
```
Suppose src is scalar at the point of assertion.
Then srcRank is 0, so dstRank is 0.
If dstRank is 0, then dst is scalar.
If they're both scalar, we would have returned early (same types).
Contradiction -- src is not scalar.
TBH this reasoning is probably too complicated, and the assertion could be replaced with an early `if (...) return Value();` bail-out:
```c++
if (!srcType) return Value();
```
https://github.com/llvm/llvm-project/pull/146368
More information about the Mlir-commits
mailing list