[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jun 30 09:20:02 PDT 2025


================
@@ -1696,59 +1696,68 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
+/// Returns true if `op` is 'broadcastlike': any BroadcastOp or SplatOp, or a
+/// ShapeCastOp that only prepends 1s to the shape of its source, for example
+/// (2,3) -> (1,1,2,3).
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // A shape_cast that only prepends dimensions cannot reduce the rank, and the
+  // trailing `srcRank` dimensions of the result must equal the source shape.
+  // Comparing the whole trailing slice (rather than just looking for leading
+  // 1s) rejects casts such as (2,3) -> (1,3,2). Any leading dimensions left
+  // over are presumably all 1s, assuming shape_cast preserves the element
+  // count (see the vector.shape_cast op documentation).
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value src = broadcastLikeOp->getOperand(0);
+
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == src.getType())
+    return src;
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Get required types and ranks in the chain
+  //    src -> broadcastDst -> dst
----------------
banach-space wrote:

I guess `ExtractOp` is missing somewhere here?

https://github.com/llvm/llvm-project/pull/146368


More information about the Mlir-commits mailing list