[Mlir-commits] [mlir] 805451f - [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (#160318)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 8 05:13:50 PDT 2025


Author: Keshav Vinayak Jha
Date: 2025-10-08T08:13:46-04:00
New Revision: 805451faedabf0e227b51683d43ebe3093115e58

URL: https://github.com/llvm/llvm-project/commit/805451faedabf0e227b51683d43ebe3093115e58
DIFF: https://github.com/llvm/llvm-project/commit/805451faedabf0e227b51683d43ebe3093115e58.diff

LOG: [MLIR][Vector] Added ToElementsOp::fold for broadcast->to_elements pattern rewrite. (#160318)

Adds `::fold` for the new `vector.to_elements` op, folding `broadcast`
into `to_elements` or no-op wherever possible.

---------

Signed-off-by: keshavvinayak01 <keshavvinayakjha at gmail.com>
Signed-off-by: Keshav Vinayak Jha <keshavvinayakjha at gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski at gmail.com>

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp
    mlir/test/Dialect/Vector/canonicalize.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..bb5d686fdd4d2 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -806,6 +806,7 @@ def Vector_ToElementsOp : Vector_Op<"to_elements", [
   let results = (outs Variadic<AnyType>:$elements);
   let assemblyFormat = "$source attr-dict `:` type($source)";
   let hasFolder = 1;
+  let hasCanonicalizer = 1;
 }
 
 def Vector_FromElementsOp : Vector_Op<"from_elements", [

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e889302f..14e235f253f69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
 
 #include <cassert>
 #include <cstdint>
+#include <numeric>
 
 #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
 // Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
   return success();
 }
 
+/// Folds vector.to_elements(vector.broadcast(%x)) when %x is not a vector.
+///
+/// Example:
+///  %b = vector.broadcast %x : f32 to vector<3xf32>
+///  %e:3 = vector.to_elements %b : vector<3xf32>
+///  user_op %e#0, %e#1, %e#2
+/// becomes:
+///  user_op %x, %x, %x
+///
+/// Broadcasts from a vector source are handled by a canonicalization pattern.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+                          SmallVectorImpl<OpFoldResult> &results) {
+  auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+  if (!bcastOp)
+    return failure();
+  // Vector sources are left to the ToElementsOfBroadcast RewritePattern.
+  if (isa<VectorType>(bcastOp.getSource().getType()))
+    return failure();
+
+  auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+
+  Value scalar = bcastOp.getSource();
+  results.assign(resultVecType.getNumElements(), scalar);
+  return success();
+}
+
 LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
                                  SmallVectorImpl<OpFoldResult> &results) {
-  return foldToElementsFromElements(*this, results);
+  if (succeeded(foldToElementsFromElements(*this, results)))
+    return success();
+  return foldToElementsOfBroadcast(*this, results);
 }
 
 LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
   return success();
 }
 
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector.
+/// - Build `vector.to_elements %v` and remap each destination element to the
+///   corresponding source element using broadcast rules (match or 1 →
+///   replicate).
+///
+/// Example:
+///   %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+///   %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+///   %src_elems:2 = vector.to_elements %src : vector<2xf32>
+///   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+///   //       %src_elems#1, %src_elems#0, %src_elems#1
+struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+                                PatternRewriter &rewriter) const override {
+    auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+    if (!bcastOp)
+      return failure();
+
+    // Only handle broadcasts from a vector source here.
+    auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+    if (!srcType)
+      return failure();
+
+    auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+    ArrayRef<int64_t> dstShape = dstType.getShape();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+
+    int64_t dstRank = dstShape.size();
+    int64_t srcRank = srcShape.size();
+
+    // Create elements for the broadcast source vector.
+    auto srcElems = vector::ToElementsOp::create(
+        rewriter, toElementsOp.getLoc(), bcastOp.getSource());
+
+    // One replacement per destination element; ask the type for the count.
+    int64_t dstCount = dstType.getNumElements();
+
+    SmallVector<Value> replacements;
+    replacements.reserve(dstCount);
+
+    // For each element of the destination, determine which element of the
+    // source should be used. We walk all destination positions using a single
+    // counter, decode it into per-dimension indices, then build the matching
+    // source position: use the same index where sizes match, and use 0 where
+    // the source size is 1 (replication). This mapping is needed so we can
+    // replace each result of to_elements with the corresponding element from
+    // the broadcast source.
+    // Inner-dimension stretch example:
+    //   %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
+    //   %e:12 = vector.to_elements %v : vector<2x3x2xf32>
+    // becomes:
+    //   %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
+    //   // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+    //   //       %src_elems#1, %src_elems#0, %src_elems#1,
+    //   //       %src_elems#2, %src_elems#3, %src_elems#2,
+    //   //       %src_elems#3, %src_elems#2, %src_elems#3
+
+    // Row-major strides for the destination shape.
+    SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+    // Row-major strides for the source shape.
+    SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+    SmallVector<int64_t> dstIdx(dstRank);
+    SmallVector<int64_t> srcIdx(srcRank);
+    for (int64_t lin = 0; lin < dstCount; ++lin) {
+      // Convert linear destination index to per-dimension indices.
+      dstIdx = delinearize(lin, dstStrides);
+      for (int64_t k = 0; k < srcRank; ++k)
+        srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
+      // Convert per-dimension source indices back to a linear index.
+      int64_t srcLin = linearize(srcIdx, srcStrides);
+      replacements.push_back(srcElems.getResult(srcLin));
+    }
+
+    rewriter.replaceOp(toElementsOp, replacements);
+    return success();
+  }
+};
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                               MLIRContext *context) {
+  results.add<ToElementsOfBroadcast>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // FromElementsOp
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index c07edacea9975..eb369c07c3e0d 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -3322,6 +3322,46 @@ func.func @from_elements_to_elements_shuffle(%a: vector<4x2xf32>) -> vector<4x2x
 
 // -----
 
+// CHECK-LABEL: func @to_elements_of_scalar_broadcast_folds
+// CHECK-SAME: (%[[S:.*]]: f32) -> (f32, f32, f32, f32)
+func.func @to_elements_of_scalar_broadcast_folds(%s: f32) -> (f32, f32, f32, f32) {
+  %v = vector.broadcast %s : f32 to vector<4xf32>
+  %e:4 = vector.to_elements %v : vector<4xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.to_elements
+  // CHECK: return %[[S]], %[[S]], %[[S]], %[[S]]
+  return %e#0, %e#1, %e#2, %e#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast
+// CHECK-SAME: (%[[VEC:.*]]: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast(%vec: vector<2xf32>) -> (f32, f32, f32, f32, f32, f32) {
+  %v = vector.broadcast %vec : vector<2xf32> to vector<3x2xf32>
+  %e:6 = vector.to_elements %v : vector<3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC_ELEMS:.*]]:2 = vector.to_elements %[[VEC]]
+  // CHECK: return %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1, %[[SRC_ELEMS]]#0, %[[SRC_ELEMS]]#1
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5 : f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func @to_elements_of_vector_broadcast_inner_dim
+// CHECK-SAME: (%[[V:.*]]: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)
+func.func @to_elements_of_vector_broadcast_inner_dim(%v: vector<2x1x2xf32>) -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) {
+  %b = vector.broadcast %v : vector<2x1x2xf32> to vector<2x3x2xf32>
+  %e:12 = vector.to_elements %b : vector<2x3x2xf32>
+  // CHECK-NOT: vector.broadcast
+  // CHECK: %[[SRC:.*]]:4 = vector.to_elements %[[V]] : vector<2x1x2xf32>
+  // CHECK: return %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#0, %[[SRC]]#1, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3, %[[SRC]]#2, %[[SRC]]#3
+  return %e#0, %e#1, %e#2, %e#3, %e#4, %e#5, %e#6, %e#7, %e#8, %e#9, %e#10, %e#11 :
+    f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32
+}
+
+// -----
+
 // +---------------------------------------------------------------------------
 // Tests for foldFromElementsToConstant
 // +---------------------------------------------------------------------------


        


More information about the Mlir-commits mailing list