[Mlir-commits] [mlir] [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (PR #160318)

Keshav Vinayak Jha llvmlistbot at llvm.org
Tue Sep 30 07:16:21 PDT 2025


https://github.com/keshavvinayak01 updated https://github.com/llvm/llvm-project/pull/160318

>From 691bf891868130e8eb66953de2648e1f1befcd31 Mon Sep 17 00:00:00 2001
From: keshavvinayak01 <keshavvinayakjha at gmail.com>
Date: Tue, 23 Sep 2025 14:59:27 +0000
Subject: [PATCH 1/6] Added folder for broadcast->to_elements

Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp   | 94 +++++++++++++++++++++-
 mlir/test/Dialect/Vector/canonicalize.mlir | 26 ++++++
 2 files changed, 119 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 347141e2773b8..4ac61418b97a5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,11 +2395,103 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
+/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
+/// vector.to_elements on the source and remapping results according to
+/// broadcast semantics.
+///
+/// Cases handled:
+///  - %x is a scalar: replicate the scalar across all results.
+///  - %x is a vector: create to_elements on source and remap/duplicate results.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+                          SmallVectorImpl<OpFoldResult> &results) {
+  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+  if (!bcastOp)
+    return failure();
+
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+  // Bail on scalable vectors.
+  if (resultVecType.getNumScalableDims() != 0)
+    return failure();
+
+  // Case 1: scalar broadcast → replicate scalar across all results.
+  if (!isa<VectorType>(bcastOp.getSource().getType())) {
+    Value scalar = bcastOp.getSource();
+    results.assign(resultVecType.getNumElements(), scalar);
+    return success();
+  }
+
+  // Case 2: vector broadcast → create to_elements on source and remap.
+  auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
+  if (srcVecType.getNumScalableDims() != 0)
+    return failure();
+
+  // Create a temporary to_elements to get the source elements for mapping.
+  // Change the operand to the broadcast source.
+  OpBuilder builder(toElementsOp);
+  auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
+                                               bcastOp.getSource());
+
+  ArrayRef<int64_t> dstShape = resultVecType.getShape();
+  ArrayRef<int64_t> srcShape = srcVecType.getShape();
+
+  // Quick broadcastability check with right-aligned shapes.
+  unsigned dstRank = dstShape.size();
+  unsigned srcRank = srcShape.size();
+  if (srcRank > dstRank)
+    return failure();
+
+  for (unsigned i = 0; i < dstRank; ++i) {
+    int64_t dstDim = dstShape[i];
+    int64_t srcDim = 1;
+    if (i + srcRank >= dstRank)
+      srcDim = srcShape[i + srcRank - dstRank];
+    if (!(srcDim == 1 || srcDim == dstDim))
+      return failure();
+  }
+
+  int64_t dstCount = 1;
+  for (int64_t v : dstShape)
+    dstCount *= v;
+  results.clear();
+  results.reserve(dstCount);
+
+  // Pre-compute the mapping from destination linear index to source linear index
+  SmallVector<int64_t> dstToSrcMap(dstCount);
+  SmallVector<int64_t> dstIdx(dstShape.size());
+  
+  for (int64_t lin = 0; lin < dstCount; ++lin) {
+    // Convert linear index to multi-dimensional indices (row-major order)
+    int64_t temp = lin;
+    for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+      int64_t dim = dstShape[i];
+      dstIdx[i] = temp % dim;
+      temp /= dim;
+    }
+    // Right-align mapping from dst indices to src indices.
+    int64_t srcLin = 0;
+    for (unsigned k = 0; k < srcRank; ++k)
+      srcLin = srcLin * srcShape[k] + 
+        ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
+    dstToSrcMap[lin] = srcLin;
+  }
+
+  // Apply the pre-computed mapping
+  for (int64_t lin = 0; lin < dstCount; ++lin) {
+    results.push_back(srcElems.getResult(dstToSrcMap[lin]));
+  }
+  return success();
+}
+
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  return foldToElementsFromElements(*this, results);
+  if (succeeded(foldToElementsFromElements(*this, results)))
+    return success();
+  return foldToElementsOfBroadcast(*this, results);
 }
 
+
 LogicalResult
 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ToElementsOp::Adaptor adaptor,
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 08d28be3f8f73..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3326,6 +3326,32 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
+  %v = vector.broadcast %s : f32 to vector<4xf32>
+  %e:4 = vector.to_elements %v : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.to_elements
+  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
+  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+  %e:6 = vector.to_elements %v : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
 // +---------------------------------------------------------------------------
 // Tests for foldFromElementsToConstant
 // +---------------------------------------------------------------------------

>From c80ccba208a3f2fcf7f7f9b92b03b60fa9168cde Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 01:54:45 -0700
Subject: [PATCH 2/6] Vector broadcast->to_elements requires a canonicalizer

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 .../mlir/Dialect/Vector/IR/VectorOps.td       |   1 +
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 153 ++++++++++--------
 mlir/test/Dialect/Vector/canonicalize.mlir    |  25 +++
 3 files changed, 113 insertions(+), 66 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 26d06624cb976..61fde8c0ba8c5 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -805,6 +805,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
   let results = (outs Variadic<AnyType>:$elements);
   let assemblyFormat = "$source attr-dict `:` type($source)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 4ac61418b97a5..fa40af6b16528 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2395,13 +2395,12 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
-/// Folds vector.to_elements(vector.broadcast(%x)) by creating a new
-/// vector.to_elements on the source and remapping results according to
-/// broadcast semantics.
+/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
 ///
 /// Cases handled:
 ///  - %x is a scalar: replicate the scalar across all results.
-///  - %x is a vector: create to_elements on source and remap/duplicate results.
+///
+/// The vector source case is handled by a canonicalization pattern.
 static LogicalResult
 foldToElementsOfBroadcast(ToElementsOp toElementsOp,
                           SmallVectorImpl<OpFoldResult> &results) {
@@ -2421,67 +2420,8 @@ foldToElementsOfBroadcast(ToElementsOp toElementsOp,
     return success();
   }
 
-  // Case 2: vector broadcast → create to_elements on source and remap.
-  auto srcVecType = cast<VectorType>(bcastOp.getSource().getType());
-  if (srcVecType.getNumScalableDims() != 0)
-    return failure();
-
-  // Create a temporary to_elements to get the source elements for mapping.
-  // Change the operand to the broadcast source.
-  OpBuilder builder(toElementsOp);
-  auto srcElems = builder.create<ToElementsOp>(toElementsOp.getLoc(),
-                                               bcastOp.getSource());
-
-  ArrayRef<int64_t> dstShape = resultVecType.getShape();
-  ArrayRef<int64_t> srcShape = srcVecType.getShape();
-
-  // Quick broadcastability check with right-aligned shapes.
-  unsigned dstRank = dstShape.size();
-  unsigned srcRank = srcShape.size();
-  if (srcRank > dstRank)
-    return failure();
-
-  for (unsigned i = 0; i < dstRank; ++i) {
-    int64_t dstDim = dstShape[i];
-    int64_t srcDim = 1;
-    if (i + srcRank >= dstRank)
-      srcDim = srcShape[i + srcRank - dstRank];
-    if (!(srcDim == 1 || srcDim == dstDim))
-      return failure();
-  }
-
-  int64_t dstCount = 1;
-  for (int64_t v : dstShape)
-    dstCount *= v;
-  results.clear();
-  results.reserve(dstCount);
-
-  // Pre-compute the mapping from destination linear index to source linear index
-  SmallVector<int64_t> dstToSrcMap(dstCount);
-  SmallVector<int64_t> dstIdx(dstShape.size());
-  
-  for (int64_t lin = 0; lin < dstCount; ++lin) {
-    // Convert linear index to multi-dimensional indices (row-major order)
-    int64_t temp = lin;
-    for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
-      int64_t dim = dstShape[i];
-      dstIdx[i] = temp % dim;
-      temp /= dim;
-    }
-    // Right-align mapping from dst indices to src indices.
-    int64_t srcLin = 0;
-    for (unsigned k = 0; k < srcRank; ++k)
-      srcLin = srcLin * srcShape[k] + 
-        ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
-
-    dstToSrcMap[lin] = srcLin;
-  }
-
-  // Apply the pre-computed mapping
-  for (int64_t lin = 0; lin < dstCount; ++lin) {
-    results.push_back(srcElems.getResult(dstToSrcMap[lin]));
-  }
-  return success();
+  // Vector source case is not folded here.
+  return failure();
 }
 
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
@@ -2491,7 +2431,6 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
   return foldToElementsOfBroadcast(*this, results);
 }
 
-
 LogicalResult
 ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
                                ToElementsOp::Adaptor adaptor,
@@ -2502,6 +2441,88 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+namespace {
+
+struct ToElementsOfVectorBroadcast final
+    : public OpRewritePattern<ToElementsOp> {
+  using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    // Bail on scalable vectors.
+    if (srcType.getNumScalableDims() != 0 || dstType.getNumScalableDims() != 0)
+      return failure();
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+
+    unsigned dstRank = dstShape.size();
+    unsigned srcRank = srcShape.size();
+    if (srcRank > dstRank)
+      return failure();
+
+    // Verify broadcastability (right-aligned)
+    for (unsigned i = 0; i < dstRank; ++i) {
+      int64_t dstDim = dstShape[i];
+      int64_t srcDim = 1;
+      if (i + srcRank >= dstRank)
+        srcDim = srcShape[i + srcRank - dstRank];
+      if (!(srcDim == 1 || srcDim == dstDim))
+        return failure();
+    }
+
+    // Create elements for the broadcast source vector.
+    auto loc = toElementsOp.getLoc();
+    auto srcElems = rewriter.create<ToElementsOp>(loc, bcastOp.getSource());
+
+    int64_t dstCount = 1;
+    for (int64_t v : dstShape)
+      dstCount *= v;
+
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+
+    // Pre-compute and apply mapping from destination linear index to
+    // source linear index (row-major, right-aligned broadcasting).
+    SmallVector<int64_t> dstIdx(dstShape.size());
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      int64_t temp = lin;
+      for (int64_t i = dstShape.size() - 1; i >= 0; --i) {
+        int64_t dim = dstShape[i];
+        dstIdx[i] = temp % dim;
+        temp /= dim;
+      }
+      int64_t srcLin = 0;
+      for (unsigned k = 0; k < srcRank; ++k)
+        srcLin = srcLin * srcShape[k] +
+                 ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
+
+      replacements.push_back(srcElems.getResult(srcLin));
+    }
+
+    rewriter.replaceOp(toElementsOp, replacements);
+    return success();
+  }
+};
+
+} // end anonymous namespace
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<ToElementsOfVectorBroadcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 728c4ddd22ec7..ab6d2ec9a835b 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3816,3 +3816,28 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
   %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
   return %v_1 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @fold_to_elements_of_scalar_broadcast
+//  CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @fold_to_elements_of_scalar_broadcast(%s: f32) -> (f32, f32, f32, f32) {
+  %v = vector.broadcast %s : f32 to vector<4xf32>
+  %e:4 = vector.to_elements %v : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]] : f32, f32, f32, f32
+  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_to_elements_of_vector_broadcast
+//  CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @canonicalize_to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+  %e:6 = vector.to_elements %v : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}

>From 5d04df4fbb5c4bee4093a4024414685f53ec9843 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 01:59:50 -0700
Subject: [PATCH 3/6] Removed duplicate lit tests

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/test/Dialect/Vector/canonicalize.mlir | 25 ----------------------
 1 file changed, 25 deletions(-)

diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index ab6d2ec9a835b..728c4ddd22ec7 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3816,28 +3816,3 @@ func.func @no_fold_insert_use_chain_mismatch_static_position(%arg : vector<4xf32
   %v_1 = vector.insert %val, %v_0[1] : f32 into vector<4xf32>
   return %v_1 : vector<4xf32>
 }
-
-// -----
-
-// CHECK-LABEL: @fold_to_elements_of_scalar_broadcast
-//  CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
-func.func @fold_to_elements_of_scalar_broadcast(%s: f32) -> (f32, f32, f32, f32) {
-  %v = vector.broadcast %s : f32 to vector<4xf32>
-  %e:4 = vector.to_elements %v : vector<4xf32>
-  // CHECK-NOT: vector.broadcast
-  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]] : f32, f32, f32, f32
-  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
-}
-
-// -----
-
-// CHECK-LABEL: @canonicalize_to_elements_of_vector_broadcast
-//  CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
-func.func @canonicalize_to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
-  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
-  %e:6 = vector.to_elements %v : vector<3x2xf32>
-  // CHECK-NOT: vector.broadcast
-  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
-  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
-  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
-}

>From 728f51a605324679f1d1cf07daabc307e1b2773f Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 04:07:32 -0700
Subject: [PATCH 4/6] Addressed review comments: 1. Better comments, removing
 keywords and verbose docs. 2. Removed the redundant "broadcastability" check;
 it is not required because a verified vector.broadcast op is always valid
 when it reaches this logic.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 49 ++++++++++--------------
 1 file changed, 20 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fa40af6b16528..04a6fbd74245f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <numeric>
 
 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
 // Pull in all enum type and utility function definitions.
@@ -2409,7 +2410,8 @@ foldToElementsOfBroadcast(ToElementsOp toElementsOp,
     return failure();
 
   auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
-  // Bail on scalable vectors.
+  // Bail on scalable vectors, since the element count and per-dimension extents
+  // must be known at compile time.
   if (resultVecType.getNumScalableDims() != 0)
     return failure();
 
@@ -2441,11 +2443,8 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
-namespace {
-
-struct ToElementsOfVectorBroadcast final
-    : public OpRewritePattern<ToElementsOp> {
-  using OpRewritePattern<ToElementsOp>::OpRewritePattern;
+class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
                                 PatternRewriter &rewriter) const override {
@@ -2467,34 +2466,28 @@ struct ToElementsOfVectorBroadcast final
     ArrayRef<int64_t> dstShape = dstType.getShape();
     ArrayRef<int64_t> srcShape = srcType.getShape();
 
-    unsigned dstRank = dstShape.size();
-    unsigned srcRank = srcShape.size();
+    int64_t dstRank = dstShape.size();
+    int64_t srcRank = srcShape.size();
     if (srcRank > dstRank)
       return failure();
 
-    // Verify broadcastability (right-aligned)
-    for (unsigned i = 0; i < dstRank; ++i) {
-      int64_t dstDim = dstShape[i];
-      int64_t srcDim = 1;
-      if (i + srcRank >= dstRank)
-        srcDim = srcShape[i + srcRank - dstRank];
-      if (!(srcDim == 1 || srcDim == dstDim))
-        return failure();
-    }
-
     // Create elements for the broadcast source vector.
-    auto loc = toElementsOp.getLoc();
-    auto srcElems = rewriter.create<ToElementsOp>(loc, bcastOp.getSource());
+    auto srcElems = rewriter.create<ToElementsOp>(toElementsOp.getLoc(),
+                                                  bcastOp.getSource());
 
-    int64_t dstCount = 1;
-    for (int64_t v : dstShape)
-      dstCount *= v;
+    int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
+                                       std::multiplies<int64_t>());
 
     SmallVector<Value> replacements;
     replacements.reserve(dstCount);
 
-    // Pre-compute and apply mapping from destination linear index to
-    // source linear index (row-major, right-aligned broadcasting).
+    // For each element of the destination, determine which element of the
+    // source should be used. We walk all destination positions using a single
+    // counter, decode it into per-dimension indices, then build the matching
+    // source position: use the same index where sizes match, and use 0 where
+    // the source size is 1 (replication). This mapping is needed so we can
+    // replace each result of to_elements with the corresponding element from
+    // the broadcast source.
     SmallVector<int64_t> dstIdx(dstShape.size());
     for (int64_t lin = 0; lin < dstCount; ++lin) {
       int64_t temp = lin;
@@ -2504,7 +2497,7 @@ struct ToElementsOfVectorBroadcast final
         temp /= dim;
       }
       int64_t srcLin = 0;
-      for (unsigned k = 0; k < srcRank; ++k)
+      for (int64_t k = 0; k < srcRank; ++k)
         srcLin = srcLin * srcShape[k] +
                  ((srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]);
 
@@ -2516,11 +2509,9 @@ struct ToElementsOfVectorBroadcast final
   }
 };
 
-} // end anonymous namespace
-
 void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
-  results.add<ToElementsOfVectorBroadcast>(context);
+  results.add<ToElementsOfBroadcast>(context);
 }
 
 //===----------------------------------------------------------------------===//

>From f9b44a9f9af6f8b8fdc8acd5efb0c5c2de01d4d7 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 04:33:59 -0700
Subject: [PATCH 5/6] Added doc for the canonicalizer

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 04a6fbd74245f..f2cba82ff73c1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2443,6 +2443,20 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector:
+/// - Build `vector.to_elements %v` and remap each destination element to the
+///   corresponding source element using broadcast rules (matching dims keep
+///   their index; size-1 and missing leading source dims are replicated).
+///
+/// Example:
+///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
+///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+///   //       %src_elems#1, %src_elems#0, %src_elems#1
+
 class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
   using OpRewritePattern::OpRewritePattern;
 

>From 18675dc2109387c7a4cc7618aa87279b457ddba2 Mon Sep 17 00:00:00 2001
From: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Date: Tue, 30 Sep 2025 07:13:35 -0700
Subject: [PATCH 6/6] Addressed comments: 1. Added better docs for the folder,
 with an example. 2. Removed the isScalable check, which is not required for
 toElementsOp. 3. Used the free create method for the new toElementsOp.

Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 38 ++++++++++--------------
 1 file changed, 15 insertions(+), 23 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index f2cba82ff73c1..015ad5d3b44bb 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2398,8 +2398,13 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
 
 /// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
 ///
-/// Cases handled:
-///  - %x is a scalar: replicate the scalar across all results.
+/// When %x is a scalar, replicate it across all results.
+/// Example:
+///  %b = vector.broadcast %x : f32 to vector<3xf32>
+///  %e:3 = vector.to_elements %b : vector<3xf32>
+///  user_op %e#0, %e#1, %e#2
+/// becomes:
+///  user_op %x, %x, %x
 ///
 /// The vector source case is handled by a canonicalization pattern.
 static LogicalResult
@@ -2408,22 +2413,15 @@ foldToElementsOfBroadcast(ToElementsOp toElementsOp,
   auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
   if (!bcastOp)
     return failure();
-
-  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
-  // Bail on scalable vectors, since the element count and per-dimension extents
-  // must be known at compile time.
-  if (resultVecType.getNumScalableDims() != 0)
+  // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
+  if (isa<VectorType>(bcastOp.getSource().getType()))
     return failure();
 
-  // Case 1: scalar broadcast → replicate scalar across all results.
-  if (!isa<VectorType>(bcastOp.getSource().getType())) {
-    Value scalar = bcastOp.getSource();
-    results.assign(resultVecType.getNumElements(), scalar);
-    return success();
-  }
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
 
-  // Vector source case is not folded here.
-  return failure();
+  Value scalar = bcastOp.getSource();
+  results.assign(resultVecType.getNumElements(), scalar);
+  return success();
 }
 
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
@@ -2473,21 +2471,15 @@ class ToElementsOfBroadcast final : public OpRewritePattern<ToElementsOp> {
 
     auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
 
-    // Bail on scalable vectors.
-    if (srcType.getNumScalableDims() != 0 || dstType.getNumScalableDims() != 0)
-      return failure();
-
     ArrayRef<int64_t> dstShape = dstType.getShape();
     ArrayRef<int64_t> srcShape = srcType.getShape();
 
     int64_t dstRank = dstShape.size();
     int64_t srcRank = srcShape.size();
-    if (srcRank > dstRank)
-      return failure();
 
     // Create elements for the broadcast source vector.
-    auto srcElems = rewriter.create<ToElementsOp>(toElementsOp.getLoc(),
-                                                  bcastOp.getSource());
+    auto srcElems = vector::ToElementsOp::create(
+        rewriter, toElementsOp.getLoc(), bcastOp.getSource());
 
     int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1,
                                        std::multiplies<int64_t>());



More information about the Mlir-commits mailing list