[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)
Keshav Vinayak Jha
llvmlistbot at llvm.org
Tue Sep 30 01:56:04 PDT 2025
================
@@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}
+/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
+/// vector.to_elements on the source and remapping results according to
+/// broadcast semantics.
+///
+/// Cases handled:
+/// - %x is a scalar: replicate the scalar across all results.
+/// - %x is a vector: create to_elements on source and remap/duplicate results.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+ // Bail on scalable vectors.
+ if (resultVecType.getNumScalableDims() != 0)
+ return failure();
+
+ // Case 1: scalar broadcast → replicate scalar across all results.
+ if (!isa<VectorType>(bcastOp.getSource().getType())) {
+ Value scalar = bcastOp.getSource();
+ results.assign(resultVecType.getNumElements(), scalar);
+ return success();
+ }
+
+ // Case 2: vector broadcast → create to_elements on source and remap.
+ auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
+ if (srcVecType.getNumScalableDims() != 0)
+ return failure();
+
+ // Create a temporary to_elements to get the source elements for mapping.
+ // Change the operand to the broadcast source.
+ OpBuilder builder(toElementsOp);
+ auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
+ bcastOp.getSource());
----------------
keshavvinayak01 wrote:
As per our discussion, this needs a canonicalizer rather than a fold, since it creates a new op.
https://github.com/llvm/llvm-project/pull/160318
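
The "remap/duplicate results" step described in the doc comment of the hunk above can be illustrated outside of MLIR. The following is a minimal sketch, not code from the PR: `broadcastSourceIndices` is a hypothetical helper name, and it assumes static shapes, row-major element order, and the usual broadcast rule that source dimensions are right-aligned with the result and that a source dimension of size 1 is stretched. For each result element it returns the index of the source element that would be reused.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not part of the PR): for every element of the
// broadcast result, in row-major order, return the row-major index of the
// source element it duplicates. Assumes static shapes and that srcShape is
// a valid broadcast source for dstShape.
std::vector<int64_t>
broadcastSourceIndices(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &dstShape) {
  const size_t srcRank = srcShape.size();
  const size_t dstRank = dstShape.size();
  assert(srcRank <= dstRank && "source rank must not exceed result rank");
  const size_t lead = dstRank - srcRank; // leading dims added by broadcast

  int64_t numDst = 1;
  for (int64_t d : dstShape)
    numDst *= d;

  // Row-major strides of the source shape.
  std::vector<int64_t> srcStrides(srcRank, 1);
  for (size_t i = srcRank; i-- > 1;)
    srcStrides[i - 1] = srcStrides[i] * srcShape[i];

  std::vector<int64_t> mapping(static_cast<size_t>(numDst));
  for (int64_t linear = 0; linear < numDst; ++linear) {
    int64_t rem = linear;
    int64_t srcIndex = 0;
    // Peel off result coordinates from innermost to outermost dimension.
    for (size_t d = dstRank; d-- > 0;) {
      const int64_t coord = rem % dstShape[d];
      rem /= dstShape[d];
      if (d < lead)
        continue; // leading dim: the whole source is duplicated
      const size_t s = d - lead;
      if (srcShape[s] != 1) // size-1 source dims are stretched, index stays 0
        srcIndex += coord * srcStrides[s];
    }
    mapping[static_cast<size_t>(linear)] = srcIndex;
  }
  return mapping;
}
```

With an empty source shape the mapping is all zeros, which corresponds to the scalar case where one value is replicated across all results. In the vector case, entry `i` of the mapping would select which result of the `to_elements` on the broadcast source stands in for result `i` of the original op.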