[Mlir-commits] [mlir] Added folder for broadcast->to_elements (PR #160318)
    Keshav Vinayak Jha 
    llvmlistbot at llvm.org
       
    Tue Sep 23 08:02:17 PDT 2025
    
    
  
https://github.com/keshavvinayak01 created https://github.com/llvm/llvm-project/pull/160318
## Description
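Extends the existing `ToElementsOp::fold` so that `vector.to_elements(vector.broadcast(%x))` folds away the broadcast: a scalar `%x` is replicated across all results, and a vector `%x` is replaced by a `vector.to_elements` on `%x` whose results are remapped according to broadcast semantics.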
>From 691bf891868130e8eb66953de2648e1f1befcd31 Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Tue, 23 Sep 2025 14:59:27 +0000
Subject: [PATCH] Added folder for broadcast->to_elements
Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 94 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 26 ++++++
 2 files changed, 119 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 347141e2773b8..4ac61418b97a5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
+/// Folds vector.to_elements(vector.broadcast(%x)) by forwarding the elements
+/// of the broadcast source to the results, following broadcast semantics.
+///
+/// Cases handled:
+///  - %x is a scalar: every result is %x itself.
+///  - %x is a fixed-size vector: a vector.to_elements is created on %x and
+///    each result is mapped to the source element it is broadcast from.
+///
+/// Returns failure, leaving `results` and the IR untouched, when the source of
+/// `toElementsOp` is not a vector.broadcast, when a scalable vector is
+/// involved, or when the shapes are not right-aligned broadcast compatible.
+///
+/// NOTE(review): the vector case creates an op from within a folder, which
+/// the MLIR fold contract does not seem to allow (folders may only update the
+/// op in place or return existing values/attributes); it likely belongs in a
+/// canonicalization pattern instead. Confirm before landing.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+                          SmallVectorImpl<OpFoldResult> &results) {
+  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+  if (!bcastOp)
+    return failure();
+
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+  // Bail on scalable vectors.
+  if (resultVecType.getNumScalableDims() != 0)
+    return failure();
+
+  // Case 1: scalar broadcast -> replicate the scalar across all results.
+  Value bcastSrc = bcastOp.getSource();
+  auto srcVecType = dyn_cast<VectorType>(bcastSrc.getType());
+  if (!srcVecType) {
+    results.assign(resultVecType.getNumElements(), bcastSrc);
+    return success();
+  }
+
+  // Case 2: vector broadcast -> create to_elements on the source and remap.
+  if (srcVecType.getNumScalableDims() != 0)
+    return failure();
+
+  ArrayRef<int64_t> dstShape = resultVecType.getShape();
+  ArrayRef<int64_t> srcShape = srcVecType.getShape();
+  int64_t dstRank = dstShape.size();
+  int64_t srcRank = srcShape.size();
+
+  // Check broadcastability with right-aligned shapes before touching the IR,
+  // so that no dead to_elements op is left behind when the fold bails out.
+  if (srcRank > dstRank)
+    return failure();
+  // Number of leading destination dims without a source counterpart.
+  int64_t rankDiff = dstRank - srcRank;
+  for (int64_t k = 0; k < srcRank; ++k) {
+    if (srcShape[k] != 1 && srcShape[k] != dstShape[rankDiff + k])
+      return failure();
+  }
+
+  // All checks passed: expose the elements of the broadcast source.
+  OpBuilder builder(toElementsOp);
+  auto srcElems =
+      builder.create<ToElementsOp>(toElementsOp.getLoc(), bcastSrc);
+
+  int64_t dstCount = resultVecType.getNumElements();
+  results.clear();
+  results.reserve(dstCount);
+
+  // Visit the destination elements in row-major order and forward, for each
+  // of them, the source element it is broadcast from.
+  SmallVector<int64_t> dstIdx(dstRank);
+  for (int64_t lin = 0; lin < dstCount; ++lin) {
+    // Delinearize `lin` into the multi-dimensional index `dstIdx`.
+    int64_t rem = lin;
+    for (int64_t i = dstRank - 1; i >= 0; --i) {
+      dstIdx[i] = rem % dstShape[i];
+      rem /= dstShape[i];
+    }
+
+    // Linearize the right-aligned source index. A unit source dim is
+    // broadcast, so it always reads index 0.
+    int64_t srcLin = 0;
+    for (int64_t k = 0; k < srcRank; ++k) {
+      int64_t srcIdx = srcShape[k] == 1 ? 0 : dstIdx[rankDiff + k];
+      srcLin = srcLin * srcShape[k] + srcIdx;
+    }
+
+    results.push_back(srcElems.getResult(srcLin));
+  }
+  return success();
+}
+
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  return foldToElementsFromElements(*this, results);
+  if (failed(foldToElementsFromElements(*this, results)))
+    return foldToElementsOfBroadcast(*this, results);
+  return success();
 }
 
+
 LogicalResult
 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ToElementsOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 08d28be3f8f73..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3326,6 +3326,32 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[SCALAR:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%scalar: f32) -> (f32, f32, f32, f32) {
+  %bcast = vector.broadcast %scalar : f32 to vector<4xf32>
+  %elems:4 = vector.to_elements %bcast : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.to_elements
+  // CHECK: return %[[SCALAR]], %[[SCALAR]], %[[SCALAR]], %[[SCALAR]]
+  return %elems#0, %elems#1, %elems#2, %elems#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[SRC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%src: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %bcast = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+  %elems:6 = vector.to_elements %bcast : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[ELEMS:.*]]:2 = vector.to_elements %[[SRC]]
+  // CHECK: return %[[ELEMS]]#0, %[[ELEMS]]#1, %[[ELEMS]]#0, %[[ELEMS]]#1, %[[ELEMS]]#0, %[[ELEMS]]#1
+  return %elems#0, %elems#1, %elems#2, %elems#3, %elems#4, %elems#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
 // +---------------------------------------------------------------------------
 // Tests for foldFromElementsToConstant
 // +---------------------------------------------------------------------------
    
    
More information about the Mlir-commits
mailing list