[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)
Andrzej Warzyński
llvmlistbot at llvm.org
Mon Jul 21 09:28:49 PDT 2025
================
@@ -1707,59 +1707,98 @@ static bool hasZeroDimVectors(Operation *op) {
llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
}
-/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
+/// Returns true if `op` is 'broadcastlike': any BroadcastOp or SplatOp, or a
+/// ShapeCastOp that only prepends 1s to its source shape.
+static bool isBroadcastLike(Operation *op) {
+ if (isa<BroadcastOp, SplatOp>(op))
+ return true;
+
+ auto shapeCast = dyn_cast<ShapeCastOp>(op);
+ if (!shapeCast)
+ return false;
+
+ // Check that shape_cast **only** prepends 1s, like (2,3) -> (1,1,2,3). A prefix
+ // of 1s in the destination shape is not sufficient: (2,3) -> (1,3,2) is not
+ // broadcastlike. A sufficient condition is that the source shape is a suffix of
+ // the destination shape, i.e. its trailing `srcRank` dims equal `srcShape`.
+ VectorType srcType = shapeCast.getSourceVectorType();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ uint64_t srcRank = srcType.getRank();
+ ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+ return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
+/// Fold extract(broadcast(X)) to either extract(X) or just X.
+///
+/// Example:
+///
+/// broadcast extract [1][2]
+/// (3, 4) --------> (2, 3, 4) ----------------> (4)
+///
+/// becomes
+/// extract [2]
+/// (3,4) -------------------------------------> (4)
+///
+///
+/// The variable names used in this implementation correspond to the above
+/// shapes as,
+///
+/// - (3, 4) is `input` shape.
+/// - (2, 3, 4) is `broadcast` shape.
+/// - (4) is `extract` shape.
+///
+/// This folding is possible when the suffix of `input` shape is the same as
+/// `extract` shape.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
+
Operation *defOp = extractOp.getVector().getDefiningOp();
- if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+ if (!defOp || !isBroadcastLike(defOp))
return Value();
- Value source = defOp->getOperand(0);
- if (extractOp.getType() == source.getType())
- return source;
- auto getRank = [](Type type) {
- return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
- : 0;
- };
+ Value input = defOp->getOperand(0);
- // If splat or broadcast from a scalar, just return the source scalar.
- unsigned broadcastSrcRank = getRank(source.getType());
- if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
- return source;
+ // Replace extract(broadcast(X)) with X
+ if (extractOp.getType() == input.getType())
+ return input;
- unsigned extractResultRank = getRank(extractOp.getType());
- if (extractResultRank > broadcastSrcRank)
- return Value();
- // Check that the dimension of the result haven't been broadcasted.
- auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
- auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
- if (extractVecType && broadcastVecType &&
- extractVecType.getShape() !=
- broadcastVecType.getShape().take_back(extractResultRank))
+ // Get required types and ranks in the chain
+ // input -> broadcast -> extract
----------------
banach-space wrote:
```suggestion
// input -> broadcast -> extract
// (scalars are treated as rank-0).
```
https://github.com/llvm/llvm-project/pull/146368
More information about the Mlir-commits
mailing list