[Mlir-commits] [mlir] [mlir][vector] Folder: shape_cast(extract) -> extract (PR #146368)

Andrzej Warzyński llvmlistbot at llvm.org
Mon Jun 30 09:20:03 PDT 2025


================
@@ -1696,59 +1696,71 @@ static bool hasZeroDimVectors(Operation *op) {
          llvm::any_of(op->getResultTypes(), hasZeroDimVectorType);
 }
 
+/// All BroadcastOps and SplatOps, as well as ShapeCastOps that only prepend
+/// 1s, are considered 'broadcastlike'.
+static bool isBroadcastLike(Operation *op) {
+  if (isa<BroadcastOp, SplatOp>(op))
+    return true;
+
+  auto shapeCast = dyn_cast<ShapeCastOp>(op);
+  if (!shapeCast)
+    return false;
+
+  // Check that it just prepends 1s, like (2,3) -> (1,1,2,3).
+  // Condition 1: dst has higher (or equal) rank.
+  // Condition 2: src shape is a suffix of dst shape.
+  VectorType srcType = shapeCast.getSourceVectorType();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  uint64_t srcRank = srcType.getRank();
+  ArrayRef<int64_t> dstShape = shapeCast.getType().getShape();
+  return dstShape.size() >= srcRank && dstShape.take_back(srcRank) == srcShape;
+}
+
 /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
 static Value foldExtractFromBroadcast(ExtractOp extractOp) {
-  Operation *defOp = extractOp.getVector().getDefiningOp();
-  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
+
+  Operation *broadcastLikeOp = extractOp.getVector().getDefiningOp();
+  if (!broadcastLikeOp || !isBroadcastLike(broadcastLikeOp))
     return Value();
 
-  Value source = defOp->getOperand(0);
-  if (extractOp.getType() == source.getType())
-    return source;
-  auto getRank = [](Type type) {
-    return llvm::isa<VectorType>(type) ? llvm::cast<VectorType>(type).getRank()
-                                       : 0;
-  };
+  Value src = broadcastLikeOp->getOperand(0);
+
+  // Replace extract(broadcast(X)) with X
+  if (extractOp.getType() == src.getType())
+    return src;
 
-  // If splat or broadcast from a scalar, just return the source scalar.
-  unsigned broadcastSrcRank = getRank(source.getType());
-  if (broadcastSrcRank == 0 && source.getType() == extractOp.getType())
-    return source;
+  // Get required types and ranks in the chain
+  //    src -> broadcastDst -> dst
+  auto srcType = llvm::dyn_cast<VectorType>(src.getType());
+  auto dstType = llvm::dyn_cast<VectorType>(extractOp.getType());
+  unsigned srcRank = srcType ? srcType.getRank() : 0;
+  unsigned broadcastDstRank = extractOp.getSourceVectorType().getRank();
+  unsigned dstRank = dstType ? dstType.getRank() : 0;
 
-  unsigned extractResultRank = getRank(extractOp.getType());
-  if (extractResultRank > broadcastSrcRank)
+  // Cannot do without the broadcast if overall the rank increases.
+  if (dstRank > srcRank)
     return Value();
-  // Check that the dimension of the result haven't been broadcasted.
-  auto extractVecType = llvm::dyn_cast<VectorType>(extractOp.getType());
-  auto broadcastVecType = llvm::dyn_cast<VectorType>(source.getType());
-  if (extractVecType && broadcastVecType &&
-      extractVecType.getShape() !=
-          broadcastVecType.getShape().take_back(extractResultRank))
+
+  assert(srcType && "src must be a vector type because of previous checks");
+
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  if (dstType && dstType.getShape() != srcShape.take_back(dstRank))
----------------
banach-space wrote:

Could you add a comment explaining what case this is?

https://github.com/llvm/llvm-project/pull/146368


More information about the Mlir-commits mailing list