[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)

Diego Caballero llvmlistbot at llvm.org
Mon Jun 30 22:54:39 PDT 2025


https://github.com/dcaballe updated https://github.com/llvm/llvm-project/pull/145740

>From 6aaa4e42756a10c95e59afaf2eba3e0b9583709c Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 24 Jun 2025 06:33:22 +0000
Subject: [PATCH 1/2] [mlir][Vector] Add `vector.shuffle` tree transformation

This PR adds a new transformation that turns sequences of `vector.to_elements`
and `vector.from_elements` into a binary tree of `vector.shuffle` operations.
(Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).

Example:

```
  %0:4 = vector.to_elements %a : vector<4xf32>
  %1:4 = vector.to_elements %b : vector<4xf32>
  %2:4 = vector.to_elements %c : vector<4xf32>
  %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3,
                            %1#0, %1#1, %1#2, %1#3,
                            %2#0, %2#1, %2#2, %2#3 : vector<12xf32>

==>

  %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
  %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
  %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
```

The algorithm leverages the structured extraction/insertion information of
`vector.to_elements` and `vector.from_elements` operations and builds a set
of intervals to determine the vector length that should be used at each level
of the tree.

There are a few improvements that can be implemented in the future, such as
shuffle mask compression to avoid unnecessarily large vector lengths with poison
values, but I decided to keep things "simpler" and spend more time documenting the
different steps of the algorithm so that people can follow along.
---
 .../Vector/Transforms/LoweringPatterns.h      |   7 +
 .../mlir/Dialect/Vector/Transforms/Passes.h   |   1 +
 .../mlir/Dialect/Vector/Transforms/Passes.td  |   5 +
 .../Dialect/Vector/Transforms/CMakeLists.txt  |   1 +
 ...LowerVectorToFromElementsToShuffleTree.cpp | 692 ++++++++++++++++++
 ...m-elements-to-shuffle-tree-transforms.mlir | 329 +++++++++
 6 files changed, 1035 insertions(+)
 create mode 100644 mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
 create mode 100644 mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir

diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
 /// n > 1.
 void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a balanced binary tree of
+/// `vector.shuffle` operations.
+///
+/// The patterns only apply to `vector.from_elements` ops with a fixed-length
+/// (non-scalable) 1-D result whose operands are all produced by
+/// `vector.to_elements` ops. A `vector.from_elements` that simply forwards the
+/// results of a single `vector.to_elements` op in their original order is
+/// left untouched. `benefit` is the benefit assigned to the added patterns.
+void populateVectorToFromElementsToShuffleTreePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
+def LowerVectorToFromElementsToShuffleTree
+  : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+  let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+  let description = [{
+    This pass rewrites `vector.from_elements` operations whose operands are
+    all produced by `vector.to_elements` operations into a balanced binary
+    tree of `vector.shuffle` operations. Input vectors are shuffled in pairs
+    by order of appearance in the `vector.from_elements` operand list, and
+    vector positions that do not need to be shuffled are represented with
+    poison in the shuffle masks.
+  }];
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
   SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- LowerVectorToFromElementsToShuffleTree.cpp - Shuffle tree ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+// Pull in the TableGen-generated pass base class definition.
+namespace mlir::vector {
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+} // namespace mlir::vector
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+/// Closed interval `[first, second]` of positions in the final output vector
+/// (e.g., [0, 7] spans 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+
+/// Marker for an interval bound that has not been assigned yet.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// Indentation scale passed to `llvm::indent` in the debug output.
+constexpr unsigned kIndScale = 2;
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each input vector to each level is used only once, except for the
+///      last `vector.to_elements` source, which is shuffled with itself when
+///      the number of `vector.to_elements` ops is odd.
+///   3. The number of levels in the tree is:
+///        ceil(log2(# `vector.to_elements` ops)).
+///   4. Vectors at each level of the tree have the same vector length.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                               : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///     : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///     : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * The shuffle tree has two levels:
+///         - Level 1 = (%shuffle0, %shuffle1)
+///         - Level 2 = (%result)
+///     * `%a` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///     * `%c` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector length for the first and second levels are 8 and 12,
+///       respectively.
+///     * `%shuffle1` uses poison values to match the vector length of its
+///       tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///     : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///     : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * `%c` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///     * `%a` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector length for the first and second levels are 8 and 9,
+///       respectively.
+///     * `%shuffle0` uses poison values to mark unused vector positions and
+///       match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElementsOp;
+  SmallVector<ToElementsOp> toElementsDefs;
+
+  // Shuffle tree configuration. `numLevels` is only meaningful after
+  // `computeShuffleTree` has run; it is zero-initialized so that it is never
+  // read uninitialized.
+  unsigned numLevels = 0;
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions in the final output that each vector input
+  /// in the tree is contributing to.
+  SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeInputVectorIntervals();
+  void computeOutputVectorSizePerLevel();
+
+  /// Dump the shuffle tree configuration.
+  void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+  assert(fromElementsOp && "from_elements op is required");
+  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+  // Inputs are combined pairwise, so an odd number of `vector.to_elements`
+  // ops is padded by repeating the last one (it gets shuffled with itself).
+  bool hasOddNumInputs = (toElementsDefs.size() & 1) != 0;
+  if (hasOddNumInputs) {
+    ToElementsOp lastInput = toElementsDefs.back();
+    toElementsDefs.push_back(lastInput);
+  }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+///    * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+///    * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+///    * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+///    * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+///    * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+///    * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+  // Map `vector.to_elements` ops to their ordinal position in the
+  // `vector.from_elements` operand list. `insert` does not overwrite existing
+  // entries, so a duplicated `vector.to_elements` op is mapped to its first
+  // occurrence.
+  DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+  for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+    toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+  // Compute intervals for each input vector in the shuffle tree. The first
+  // level computation is special-cased to keep the implementation simpler.
+
+  SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+                                            {kMaxUnsigned, kMaxUnsigned});
+
+  for (const auto &[idx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+    // Use `find` instead of `operator[]`: every operand's defining op must be
+    // in the map, and `operator[]` would silently insert (and use) a wrong
+    // ordinal of 0 if it were not.
+    auto ordinalIt = toElementsToInputOrdinal.find(toElementsOp);
+    assert(ordinalIt != toElementsToInputOrdinal.end() &&
+           "Unexpected `vector.to_elements` op");
+    Interval &currentInterval = firstLevelIntervals[ordinalIt->second];
+
+    // Set lower bound to the first occurrence of the `vector.to_elements`.
+    if (currentInterval.first == kMaxUnsigned)
+      currentInterval.first = idx;
+
+    // Set upper bound to the last occurrence of the `vector.to_elements`.
+    currentInterval.second = idx;
+  }
+
+  // If the number of `vector.to_elements` is odd and the last op was
+  // duplicated, the interval for the duplicated op was not computed in the
+  // previous step as all the input occurrences were mapped to the original op.
+  // We copy the interval of the original op to the interval of the duplicated
+  // op manually.
+  if (firstLevelIntervals.back().second == kMaxUnsigned)
+    firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+  inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+  // Compute intervals for the remaining levels.
+  unsigned outputNumElements =
+      cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+  for (unsigned level = 1; level < numLevels; ++level) {
+    const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+    SmallVector<Interval> currentLevelIntervals(
+        llvm::divideCeil(prevLevelIntervals.size(), 2),
+        {kMaxUnsigned, kMaxUnsigned});
+
+    for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+         ++inputIdx) {
+      auto &interval = currentLevelIntervals[inputIdx];
+      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+      // The interval of a vector at the current level is the union of the
+      // intervals of the two input vectors from the previous level being
+      // shuffled at this level.
+      interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+      interval.second =
+          std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+                   outputNumElements - 1);
+    }
+
+    inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+  }
+}
+
+/// Compute, for every tree level, the single vector length used by all the
+/// shuffles of that level. Inputs at a level are paired as (0, 1), (2, 3), ...
+/// and the length of a level is the widest union of a pair of input intervals
+/// (never less than 1).
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   Intervals:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///
+///   Vector sizes:
+///     * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+///                    size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+///     * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+  for (unsigned level = 0; level != numLevels; ++level) {
+    ArrayRef<Interval> intervals = inputIntervalsPerLevel[level];
+    unsigned widest = 1;
+    for (size_t lhsIdx = 0, e = intervals.size(); lhsIdx < e; lhsIdx += 2) {
+      const Interval &lhs = intervals[lhsIdx];
+      const Interval &rhs = intervals[lhsIdx + 1];
+      // The union starts at the LHS lower bound and ends at the largest upper
+      // bound of the pair.
+      unsigned upper = std::max(lhs.second, rhs.second);
+      widest = std::max(widest, upper - lhs.first + 1);
+    }
+    vectorSizePerLevel[level] = widest;
+  }
+}
+
+/// Print the inputs and the computed tree configuration (levels, per-level
+/// vector sizes and input intervals) to `llvm::dbgs()` when debug output is
+/// enabled for DEBUG_TYPE.
+void VectorShuffleTreeBuilder::dump() {
+  LLVM_DEBUG({
+    llvm::raw_ostream &os = llvm::dbgs();
+    unsigned depth = 0;
+    auto pad = [&]() { return llvm::indent(depth, kIndScale); };
+
+    os << "VectorShuffleTreeBuilder Configuration:\n";
+    depth = 1;
+    os << pad() << "* Inputs:\n";
+    depth = 2;
+    for (ToElementsOp input : toElementsDefs)
+      os << pad() << input << "\n";
+    os << pad() << fromElementsOp << "\n\n";
+    depth = 1;
+
+    os << pad() << "* Total levels: " << numLevels << "\n";
+    os << pad() << "* Vector sizes per level: [";
+    llvm::interleaveComma(vectorSizePerLevel, os);
+    os << "]\n";
+    os << pad() << "* Input intervals per level:\n";
+    depth = 2;
+    for (const auto &[level, intervals] :
+         llvm::enumerate(inputIntervalsPerLevel)) {
+      os << pad() << "* Level " << level << ": ";
+      llvm::interleaveComma(intervals, os, [&](const Interval &interval) {
+        os << "[" << interval.first << "," << interval.second << "]";
+      });
+      os << "\n";
+    }
+  });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   build a tree that looks like:
+///
+///          %2    %1                   %0    %0
+///            \  /                       \  /
+///  %2_1 = vector.shuffle     %0_0 = vector.shuffle
+///              \                    /
+///             %2_1_0_0 =vector.shuffle
+///
+/// The configuration comprises of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please, take a look at the
+/// corresponding utility functions.
+///
+/// Every level combines its inputs in pairs, so the number of tree inputs
+/// (after the last one is duplicated when the count is odd) must be a power of
+/// two. Returns failure for any other number of inputs.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+  assert(toElementsDefs.size() > 1 &&
+         "At least two 'vector.to_elements' ops are required");
+
+  // With, e.g., 5 or 6 `vector.to_elements` ops (6 inputs), the second level
+  // would have 3 inputs: the per-level interval and vector size computations
+  // would index past the end of their arrays and the tree would not reduce to
+  // a single result. Bail out until odd intermediate levels are supported.
+  unsigned numInputs = toElementsDefs.size();
+  if (!llvm::isPowerOf2_64(numInputs)) {
+    LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder: unsupported number "
+                               "of inputs: "
+                            << numInputs << "\n");
+    return failure();
+  }
+
+  // Initialize shuffle tree information based on its size.
+  numLevels = llvm::Log2_64(numInputs);
+  vectorSizePerLevel.resize(numLevels, 0);
+  inputIntervalsPerLevel.reserve(numLevels);
+
+  computeInputVectorIntervals();
+  computeOutputVectorSizePerLevel();
+  dump();
+
+  return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   =>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+/// `interval0` must start no later than `interval1`, so the combined interval
+/// of both inputs starts at `interval0.first`.
+///
+/// TODO: Implement mask compression.
+static SmallVector<int64_t> computePermutationShuffleMask(
+    ToElementsOp toElementOp0, const Interval &interval0,
+    ToElementsOp toElementOp1, const Interval &interval1,
+    FromElementsOp fromElementsOp, unsigned outputVectorSize) {
+  assert(interval0.first <= interval1.first &&
+         "Expected the LHS interval to start before the RHS interval");
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+  unsigned inputVectorSize =
+      toElementOp0.getSource().getType().getNumElements();
+
+  for (const auto &[inputIdx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
+    // Match `vector.from_elements` operands to the two input ops.
+    if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
+      continue;
+
+    // The permutation value for a particular operand is the ordinal position of
+    // the operand in the `vector.to_elements` list of results. Elements of the
+    // second input are offset by the length of the first one, as
+    // `vector.shuffle` indexes the concatenation of both operands.
+    int64_t permVal = cast<OpResult>(element).getResultNumber();
+    if (currentToElemOp != toElementOp0)
+      permVal += inputVectorSize;
+
+    // The mask index is the ordinal position of the operand in the
+    // `vector.from_elements` operand list, made relative to the start of the
+    // combined interval of both inputs (`interval0.first`). This is the same
+    // for both inputs.
+    unsigned maskIdx = inputIdx - interval0.first;
+    assert(maskIdx < outputVectorSize && "Mask index out of bounds");
+    mask[maskIdx] = permVal;
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Combining: " << toElementOp0 << " and " << toElementOp1
+                 << "\n";
+  });
+
+  return mask;
+}
+
+/// Compute the propagation shuffle mask for combining two intermediate shuffle
+/// operations of the tree. The propagation shuffle mask is the mapping of the
+/// intermediate vector elements, which have already been shuffled to their
+/// relative output position using the mask generated by
+/// `computePermutationShuffleMask`, to their next position in the tree.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+///   =>
+///
+///   // Level 1, vector length = 9
+///   PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///
+/// `lhsInterval` must start no later than `rhsInterval`. The non-poison
+/// positions of both input masks must map to disjoint output positions.
+///
+/// TODO: Implement mask compression.
+///
+static SmallVector<int64_t> computePropagationShuffleMask(
+    ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
+    const Interval &rhsInterval, unsigned outputVectorSize) {
+  ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
+  ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
+  unsigned inputVectorSize = lhsShuffleMask.size();
+  assert(inputVectorSize == rhsShuffleMask.size() &&
+         "Expected both shuffle masks to have the same size");
+  // The offset below is unsigned; a RHS interval starting before the LHS one
+  // would wrap around.
+  assert(lhsInterval.first <= rhsInterval.first &&
+         "Expected the LHS interval to start before the RHS interval");
+
+  unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+
+  // Propagate any element from the input mask that is not poison. For the RHS
+  // input vector, the mask index is offset by the offset between the two
+  // intervals of the input vectors. Both writes check that the slot is still
+  // poison: a RHS write at iteration `j` can target the slot a later LHS write
+  // at iteration `j + lhsRhsOffset` would use, and vice versa.
+  for (unsigned i = 0; i < inputVectorSize; ++i) {
+    if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
+      assert(i < outputVectorSize && "LHS index out of bounds");
+      assert(mask[i] == ShuffleOp::kPoisonIndex && "mask already set");
+      mask[i] = i;
+    }
+
+    if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
+      unsigned rhsIdx = i + lhsRhsOffset;
+      assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
+      assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
+      mask[rhsIdx] = i + inputVectorSize;
+    }
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Propagation shuffle mask computation:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* LHS shuffle op: " << lhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* RHS shuffle op: " << rhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+  });
+
+  return mask;
+}
+
+/// Emit the `vector.shuffle` ops described by the configuration computed in
+/// `computeShuffleTree` and return the value produced by the root of the tree.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   with the pre-computed shuffle tree configuration:
+///
+///     * Vector sizes per level: [8, 9]
+///     * Input intervals per level:
+///       * Level 0: [0,6], [1,7], [2,8], [2,8]
+///       * Level 1: [0,7], [2,8]
+///
+///   =>
+///
+///    %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
+///        : vector<5xf32>, vector<5xf32>
+///    %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
+///        : vector<5xf32>, vector<5xf32>
+///    %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///        : vector<8xf32>, vector<8xf32>
+///
+/// Each level combines its inputs in consecutive pairs using that level's
+/// vector size and intervals. Level 0 uses permutation masks
+/// (`computePermutationShuffleMask`), which move every source element to its
+/// position relative to the final output. Later levels use propagation masks
+/// (`computePropagationShuffleMask`), which only merge already-placed elements.
+///
+Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n");
+
+  // Level 0 consumes the sources of the `vector.to_elements` ops.
+  SmallVector<Value> levelInputs;
+  levelInputs.reserve(toElementsDefs.size());
+  for (ToElementsOp toElementsOp : toElementsDefs)
+    levelInputs.push_back(toElementsOp.getSource());
+
+  assert(vectorSizePerLevel.size() == inputIntervalsPerLevel.size() &&
+         "Iterator lengths don't match");
+
+  Location loc = fromElementsOp.getLoc();
+  for (unsigned level = 0, numTreeLevels = vectorSizePerLevel.size();
+       level < numTreeLevels; ++level) {
+    unsigned levelVectorSize = vectorSizePerLevel[level];
+    ArrayRef<Interval> inputIntervals = inputIntervalsPerLevel[level];
+    LLVM_DEBUG(llvm::dbgs()
+               << llvm::indent(1, kIndScale) << "* Processing level " << level
+               << " (vector size: " << levelVectorSize
+               << ", # inputs: " << levelInputs.size() << ")\n");
+
+    SmallVector<Value> levelOutputs;
+    for (size_t lhsIdx = 0; lhsIdx < levelInputs.size(); lhsIdx += 2) {
+      size_t rhsIdx = lhsIdx + 1;
+      Value lhsVector = levelInputs[lhsIdx];
+      Value rhsVector = levelInputs[rhsIdx];
+      const Interval &lhsInterval = inputIntervals[lhsIdx];
+      const Interval &rhsInterval = inputIntervals[rhsIdx];
+
+      SmallVector<int64_t> shuffleMask =
+          level == 0
+              ? computePermutationShuffleMask(
+                    toElementsDefs[lhsIdx], lhsInterval, toElementsDefs[rhsIdx],
+                    rhsInterval, fromElementsOp, levelVectorSize)
+              : computePropagationShuffleMask(
+                    cast<ShuffleOp>(lhsVector.getDefiningOp()), lhsInterval,
+                    cast<ShuffleOp>(rhsVector.getDefiningOp()), rhsInterval,
+                    levelVectorSize);
+
+      Value combined = rewriter.create<vector::ShuffleOp>(
+          loc, lhsVector, rhsVector, shuffleMask);
+      levelOutputs.push_back(combined);
+    }
+
+    levelInputs = std::move(levelOutputs);
+  }
+
+  assert(levelInputs.size() == 1 && "Should have exactly one result");
+  return levelInputs.front();
+}
+
+/// Collect the distinct `vector.to_elements` ops that define the operands of
+/// `fromElementsOp`, in order of first appearance in its operand list, into
+/// `toElementsDefs`. Fails (leaving `toElementsDefs` untouched) if any operand
+/// is not produced by a `vector.to_elements` op.
+static LogicalResult
+getToElementsDefiningOps(FromElementsOp fromElementsOp,
+                         SmallVectorImpl<ToElementsOp> &toElementsDefs) {
+  SetVector<ToElementsOp> uniqueDefs;
+  for (Value element : fromElementsOp.getElements()) {
+    if (auto definingOp = element.getDefiningOp<ToElementsOp>()) {
+      uniqueDefs.insert(definingOp);
+      continue;
+    }
+    return failure();
+  }
+
+  toElementsDefs.assign(uniqueDefs.begin(), uniqueDefs.end());
+  return success();
+}
+
+/// Pattern to rewrite `vector.to_elements` + `vector.from_elements` sequences
+/// into a tree of `vector.shuffle` operations.
+///
+/// Only `vector.from_elements` ops with a fixed-length 1-D result whose
+/// operands are all produced by 1-D `vector.to_elements` ops are rewritten.
+struct ToFromElementsToShuffleTreeRewrite final
+    : OpRewritePattern<vector::FromElementsOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = fromElementsOp.getType();
+    if (resultType.getRank() != 1 || resultType.isScalable())
+      return failure();
+
+    SmallVector<ToElementsOp> toElementsDefs;
+    if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs)))
+      return failure();
+
+    // `vector.shuffle` permutes along the leading dimension and requires its
+    // operands to have the same rank as its (here 1-D) result, while the mask
+    // computation treats `vector.to_elements` result numbers as positions in
+    // that leading dimension. Only 1-D sources satisfy both; bail out on any
+    // other rank instead of generating invalid IR.
+    if (llvm::any_of(toElementsDefs, [](ToElementsOp toElementsOp) {
+          return toElementsOp.getSource().getType().getRank() != 1;
+        }))
+      return failure();
+
+    // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
+    // `vector.from_elements` forwarding cases that do not require shuffling.
+    if (toElementsDefs.size() == 1) {
+      ToElementsOp toElementsOp0 = toElementsDefs.front();
+      if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults()))
+        return failure();
+    }
+
+    VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs);
+    if (failed(shuffleTreeBuilder.computeShuffleTree()))
+      return failure();
+
+    Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
+    rewriter.replaceOp(fromElementsOp, finalShuffle);
+    return success();
+  }
+};
+
+/// Pass that greedily applies the `vector.to_elements` +
+/// `vector.from_elements` to shuffle tree patterns and signals failure if the
+/// greedy rewrite driver does not succeed.
+struct LowerVectorToFromElementsToShuffleTreePass
+    : public vector::impl::LowerVectorToFromElementsToShuffleTreeBase<
+          LowerVectorToFromElementsToShuffleTreePass> {
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+    RewritePatternSet patterns(ctx);
+    populateVectorToFromElementsToShuffleTreePatterns(patterns);
+
+    LogicalResult status =
+        applyPatternsGreedily(getOperation(), std::move(patterns));
+    if (failed(status))
+      signalPassFailure();
+  }
+};
+
+} // namespace
+
+/// Register the `vector.to_elements` + `vector.from_elements` to shuffle tree
+/// rewrite pattern in `patterns` with the given `benefit`.
+void mlir::vector::populateVectorToFromElementsToShuffleTreePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<ToFromElementsToShuffleTreeRewrite>(context, benefit);
+}
diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
new file mode 100644
index 0000000000000..3dc579be12f0f
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
@@ -0,0 +1,329 @@
+// RUN: mlir-opt -lower-vector-to-from-elements-to-shuffle-tree -split-input-file %s | FileCheck %s
+
+// Captured variable names for `vector.shuffle` operations follow the L#SH# convention,
+// where L# refers to the level of the tree the shuffle belongs to, and SH# refers to
+// the shuffle index within that level.
+
+// Permutes the elements of a single vector: one shuffle of the input with
+// itself is expected.
+func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
+  %0:8 = vector.to_elements %a : vector<8xf32>
+  %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_single_input_shuffle(
+//  CHECK-SAME:     %[[A:.*]]: vector<8xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32>
+//       CHECK:   return %[[L0SH0]]
+
+// -----
+
+// Interleaves elements of two vectors: a single two-input shuffle is expected.
+func.func @to_from_elements_two_inputs_single_shuffle(%a: vector<8xf32>,
+                                                      %b: vector<8xf32>) -> vector<8xf32> {
+  %0:8 = vector.to_elements %a : vector<8xf32>
+  %1:8 = vector.to_elements %b : vector<8xf32>
+  %2 = vector.from_elements %0#7, %1#0, %0#6, %1#1, %0#5, %1#2, %0#4, %1#3 : vector<8xf32>
+  return %2 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_two_inputs_single_shuffle(
+//  CHECK-SAME:     %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [7, 8, 6, 9, 5, 10, 4, 11] : vector<8xf32>, vector<8xf32>
+//       CHECK:   return %[[L0SH0]]
+
+// -----
+
+// Concatenates four 8-element vectors in order; a balanced two-level shuffle
+// tree is expected.
+func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
+                                                          %b: vector<8xf32>,
+                                                          %c: vector<8xf32>,
+                                                          %d: vector<8xf32>) -> vector<32xf32> {
+  %0:8 = vector.to_elements %a : vector<8xf32>
+  %1:8 = vector.to_elements %b : vector<8xf32>
+  %2:8 = vector.to_elements %c : vector<8xf32>
+  %3:8 = vector.to_elements %d : vector<8xf32>
+  %4 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7,
+                            %1#0, %1#1, %1#2, %1#3, %1#4, %1#5, %1#6, %1#7,
+                            %2#0, %2#1, %2#2, %2#3, %2#4, %2#5, %2#6, %2#7,
+                            %3#0, %3#1, %3#2, %3#3, %3#4, %3#5, %3#6, %3#7 : vector<32xf32>
+  return %4 : vector<32xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32(
+//  CHECK-SAME:     %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   return %[[L1SH0]] : vector<32xf32>
+
+// -----
+
+// Concatenates three 4-element vectors in order. With an odd number of inputs,
+// the last vector is expected to be shuffled with itself, padded with poison
+// (-1) lanes.
+func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
+                                                          %b: vector<4xf32>,
+                                                          %c: vector<4xf32>) -> vector<12xf32> {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  %1:4 = vector.to_elements %b : vector<4xf32>
+  %2:4 = vector.to_elements %c : vector<4xf32>
+  %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3 : vector<12xf32>
+  return %3 : vector<12xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12(
+//  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
+//       CHECK:   return %[[L1SH0]] : vector<12xf32>
+
+// -----
+
+// Concatenates 64 4-element vectors in order; a six-level balanced shuffle
+// tree whose vector length doubles at each level is expected.
+func.func @to_from_elements_shuffle_tree_concat_64x4_to_256(
+  %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
+  %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
+  %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
+  %m: vector<4xf32>, %n: vector<4xf32>, %o: vector<4xf32>, %p: vector<4xf32>,
+  %q: vector<4xf32>, %r: vector<4xf32>, %s: vector<4xf32>, %t: vector<4xf32>,
+  %u: vector<4xf32>, %v: vector<4xf32>, %w: vector<4xf32>, %x: vector<4xf32>,
+  %y: vector<4xf32>, %z: vector<4xf32>, %aa: vector<4xf32>, %ab: vector<4xf32>,
+  %ac: vector<4xf32>, %ad: vector<4xf32>, %ae: vector<4xf32>, %af: vector<4xf32>,
+  %ag: vector<4xf32>, %ah: vector<4xf32>, %ai: vector<4xf32>, %aj: vector<4xf32>,
+  %ak: vector<4xf32>, %al: vector<4xf32>, %am: vector<4xf32>, %an: vector<4xf32>,
+  %ao: vector<4xf32>, %ap: vector<4xf32>, %aq: vector<4xf32>, %ar: vector<4xf32>,
+  %as: vector<4xf32>, %at: vector<4xf32>, %au: vector<4xf32>, %av: vector<4xf32>,
+  %aw: vector<4xf32>, %ax: vector<4xf32>, %ay: vector<4xf32>, %az: vector<4xf32>,
+  %ba: vector<4xf32>, %bb: vector<4xf32>, %bc: vector<4xf32>, %bd: vector<4xf32>,
+  %be: vector<4xf32>, %bf: vector<4xf32>, %bg: vector<4xf32>, %bh: vector<4xf32>,
+  %bi: vector<4xf32>, %bj: vector<4xf32>, %bk: vector<4xf32>, %bl: vector<4xf32>) -> vector<256xf32> {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  %1:4 = vector.to_elements %b : vector<4xf32>
+  %2:4 = vector.to_elements %c : vector<4xf32>
+  %3:4 = vector.to_elements %d : vector<4xf32>
+  %4:4 = vector.to_elements %e : vector<4xf32>
+  %5:4 = vector.to_elements %f : vector<4xf32>
+  %6:4 = vector.to_elements %g : vector<4xf32>
+  %7:4 = vector.to_elements %h : vector<4xf32>
+  %8:4 = vector.to_elements %i : vector<4xf32>
+  %9:4 = vector.to_elements %j : vector<4xf32>
+  %10:4 = vector.to_elements %k : vector<4xf32>
+  %11:4 = vector.to_elements %l : vector<4xf32>
+  %12:4 = vector.to_elements %m : vector<4xf32>
+  %13:4 = vector.to_elements %n : vector<4xf32>
+  %14:4 = vector.to_elements %o : vector<4xf32>
+  %15:4 = vector.to_elements %p : vector<4xf32>
+  %16:4 = vector.to_elements %q : vector<4xf32>
+  %17:4 = vector.to_elements %r : vector<4xf32>
+  %18:4 = vector.to_elements %s : vector<4xf32>
+  %19:4 = vector.to_elements %t : vector<4xf32>
+  %20:4 = vector.to_elements %u : vector<4xf32>
+  %21:4 = vector.to_elements %v : vector<4xf32>
+  %22:4 = vector.to_elements %w : vector<4xf32>
+  %23:4 = vector.to_elements %x : vector<4xf32>
+  %24:4 = vector.to_elements %y : vector<4xf32>
+  %25:4 = vector.to_elements %z : vector<4xf32>
+  %26:4 = vector.to_elements %aa : vector<4xf32>
+  %27:4 = vector.to_elements %ab : vector<4xf32>
+  %28:4 = vector.to_elements %ac : vector<4xf32>
+  %29:4 = vector.to_elements %ad : vector<4xf32>
+  %30:4 = vector.to_elements %ae : vector<4xf32>
+  %31:4 = vector.to_elements %af : vector<4xf32>
+  %32:4 = vector.to_elements %ag : vector<4xf32>
+  %33:4 = vector.to_elements %ah : vector<4xf32>
+  %34:4 = vector.to_elements %ai : vector<4xf32>
+  %35:4 = vector.to_elements %aj : vector<4xf32>
+  %36:4 = vector.to_elements %ak : vector<4xf32>
+  %37:4 = vector.to_elements %al : vector<4xf32>
+  %38:4 = vector.to_elements %am : vector<4xf32>
+  %39:4 = vector.to_elements %an : vector<4xf32>
+  %40:4 = vector.to_elements %ao : vector<4xf32>
+  %41:4 = vector.to_elements %ap : vector<4xf32>
+  %42:4 = vector.to_elements %aq : vector<4xf32>
+  %43:4 = vector.to_elements %ar : vector<4xf32>
+  %44:4 = vector.to_elements %as : vector<4xf32>
+  %45:4 = vector.to_elements %at : vector<4xf32>
+  %46:4 = vector.to_elements %au : vector<4xf32>
+  %47:4 = vector.to_elements %av : vector<4xf32>
+  %48:4 = vector.to_elements %aw : vector<4xf32>
+  %49:4 = vector.to_elements %ax : vector<4xf32>
+  %50:4 = vector.to_elements %ay : vector<4xf32>
+  %51:4 = vector.to_elements %az : vector<4xf32>
+  %52:4 = vector.to_elements %ba : vector<4xf32>
+  %53:4 = vector.to_elements %bb : vector<4xf32>
+  %54:4 = vector.to_elements %bc : vector<4xf32>
+  %55:4 = vector.to_elements %bd : vector<4xf32>
+  %56:4 = vector.to_elements %be : vector<4xf32>
+  %57:4 = vector.to_elements %bf : vector<4xf32>
+  %58:4 = vector.to_elements %bg : vector<4xf32>
+  %59:4 = vector.to_elements %bh : vector<4xf32>
+  %60:4 = vector.to_elements %bi : vector<4xf32>
+  %61:4 = vector.to_elements %bj : vector<4xf32>
+  %62:4 = vector.to_elements %bk : vector<4xf32>
+  %63:4 = vector.to_elements %bl : vector<4xf32>
+  %64 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1, %1#2, %1#3, %2#0, %2#1, %2#2, %2#3, %3#0, %3#1, %3#2, %3#3, %4#0, %4#1, %4#2, %4#3,
+                             %5#0, %5#1, %5#2, %5#3, %6#0, %6#1, %6#2, %6#3, %7#0, %7#1, %7#2, %7#3, %8#0, %8#1, %8#2, %8#3, %9#0, %9#1, %9#2, %9#3,
+                             %10#0, %10#1, %10#2, %10#3, %11#0, %11#1, %11#2, %11#3, %12#0, %12#1, %12#2, %12#3, %13#0, %13#1, %13#2, %13#3, %14#0, %14#1, %14#2, %14#3,
+                             %15#0, %15#1, %15#2, %15#3, %16#0, %16#1, %16#2, %16#3, %17#0, %17#1, %17#2, %17#3, %18#0, %18#1, %18#2, %18#3, %19#0, %19#1, %19#2, %19#3,
+                             %20#0, %20#1, %20#2, %20#3, %21#0, %21#1, %21#2, %21#3, %22#0, %22#1, %22#2, %22#3, %23#0, %23#1, %23#2, %23#3, %24#0, %24#1, %24#2, %24#3,
+                             %25#0, %25#1, %25#2, %25#3, %26#0, %26#1, %26#2, %26#3, %27#0, %27#1, %27#2, %27#3, %28#0, %28#1, %28#2, %28#3, %29#0, %29#1, %29#2, %29#3,
+                             %30#0, %30#1, %30#2, %30#3, %31#0, %31#1, %31#2, %31#3, %32#0, %32#1, %32#2, %32#3, %33#0, %33#1, %33#2, %33#3, %34#0, %34#1, %34#2, %34#3,
+                             %35#0, %35#1, %35#2, %35#3, %36#0, %36#1, %36#2, %36#3, %37#0, %37#1, %37#2, %37#3, %38#0, %38#1, %38#2, %38#3, %39#0, %39#1, %39#2, %39#3,
+                             %40#0, %40#1, %40#2, %40#3, %41#0, %41#1, %41#2, %41#3, %42#0, %42#1, %42#2, %42#3, %43#0, %43#1, %43#2, %43#3, %44#0, %44#1, %44#2, %44#3,
+                             %45#0, %45#1, %45#2, %45#3, %46#0, %46#1, %46#2, %46#3, %47#0, %47#1, %47#2, %47#3, %48#0, %48#1, %48#2, %48#3, %49#0, %49#1, %49#2, %49#3,
+                             %50#0, %50#1, %50#2, %50#3, %51#0, %51#1, %51#2, %51#3, %52#0, %52#1, %52#2, %52#3, %53#0, %53#1, %53#2, %53#3, %54#0, %54#1, %54#2, %54#3,
+                             %55#0, %55#1, %55#2, %55#3, %56#0, %56#1, %56#2, %56#3, %57#0, %57#1, %57#2, %57#3, %58#0, %58#1, %58#2, %58#3, %59#0, %59#1, %59#2, %59#3,
+                             %60#0, %60#1, %60#2, %60#3, %61#0, %61#1, %61#2, %61#3, %62#0, %62#1, %62#2, %62#3, %63#0, %63#1, %63#2, %63#3 : vector<256xf32>
+  return %64 : vector<256xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_64x4_to_256(
+//  CHECK-SAME:     %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>)
+//       CHECK:   %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH2:.+]] = vector.shuffle %[[E]], %[[F]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH3:.+]] = vector.shuffle %[[G]], %[[H]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH4:.+]] = vector.shuffle %[[I]], %[[J]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH5:.+]] = vector.shuffle %[[K]], %[[L]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH6:.+]] = vector.shuffle %[[M]], %[[N]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH7:.+]] = vector.shuffle %[[O]], %[[P]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH8:.+]] = vector.shuffle %[[Q]], %[[R]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH9:.+]] = vector.shuffle %[[S]], %[[T]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH10:.+]] = vector.shuffle %[[U]], %[[V]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH11:.+]] = vector.shuffle %[[W]], %[[X]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH12:.+]] = vector.shuffle %[[Y]], %[[Z]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH13:.+]] = vector.shuffle %[[AA]], %[[AB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH14:.+]] = vector.shuffle %[[AC]], %[[AD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH15:.+]] = vector.shuffle %[[AE]], %[[AF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH16:.+]] = vector.shuffle %[[AG]], %[[AH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH17:.+]] = vector.shuffle %[[AI]], %[[AJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH18:.+]] = vector.shuffle %[[AK]], %[[AL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH19:.+]] = vector.shuffle %[[AM]], %[[AN]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH20:.+]] = vector.shuffle %[[AO]], %[[AP]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH21:.+]] = vector.shuffle %[[AQ]], %[[AR]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH22:.+]] = vector.shuffle %[[AS]], %[[AT]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH23:.+]] = vector.shuffle %[[AU]], %[[AV]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH24:.+]] = vector.shuffle %[[AW]], %[[AX]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH25:.+]] = vector.shuffle %[[AY]], %[[AZ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH26:.+]] = vector.shuffle %[[BA]], %[[BB]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH27:.+]] = vector.shuffle %[[BC]], %[[BD]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH28:.+]] = vector.shuffle %[[BE]], %[[BF]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH29:.+]] = vector.shuffle %[[BG]], %[[BH]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH30:.+]] = vector.shuffle %[[BI]], %[[BJ]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH31:.+]] = vector.shuffle %[[BK]], %[[BL]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L1SH0:.+]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH1:.+]] = vector.shuffle %[[L0SH2]], %[[L0SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH2:.+]] = vector.shuffle %[[L0SH4]], %[[L0SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH3:.+]] = vector.shuffle %[[L0SH6]], %[[L0SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH4:.+]] = vector.shuffle %[[L0SH8]], %[[L0SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH5:.+]] = vector.shuffle %[[L0SH10]], %[[L0SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH6:.+]] = vector.shuffle %[[L0SH12]], %[[L0SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH7:.+]] = vector.shuffle %[[L0SH14]], %[[L0SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH8:.+]] = vector.shuffle %[[L0SH16]], %[[L0SH17]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH9:.+]] = vector.shuffle %[[L0SH18]], %[[L0SH19]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH10:.+]] = vector.shuffle %[[L0SH20]], %[[L0SH21]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH11:.+]] = vector.shuffle %[[L0SH22]], %[[L0SH23]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH12:.+]] = vector.shuffle %[[L0SH24]], %[[L0SH25]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH13:.+]] = vector.shuffle %[[L0SH26]], %[[L0SH27]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH14:.+]] = vector.shuffle %[[L0SH28]], %[[L0SH29]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L1SH15:.+]] = vector.shuffle %[[L0SH30]], %[[L0SH31]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
+//       CHECK:   %[[L2SH0:.+]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH1:.+]] = vector.shuffle %[[L1SH2]], %[[L1SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH2:.+]] = vector.shuffle %[[L1SH4]], %[[L1SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH3:.+]] = vector.shuffle %[[L1SH6]], %[[L1SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH4:.+]] = vector.shuffle %[[L1SH8]], %[[L1SH9]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH5:.+]] = vector.shuffle %[[L1SH10]], %[[L1SH11]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH6:.+]] = vector.shuffle %[[L1SH12]], %[[L1SH13]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L2SH7:.+]] = vector.shuffle %[[L1SH14]], %[[L1SH15]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   %[[L3SH0:.+]] = vector.shuffle %[[L2SH0]], %[[L2SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+//       CHECK:   %[[L3SH1:.+]] = vector.shuffle %[[L2SH2]], %[[L2SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+//       CHECK:   %[[L3SH2:.+]] = vector.shuffle %[[L2SH4]], %[[L2SH5]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+//       CHECK:   %[[L3SH3:.+]] = vector.shuffle %[[L2SH6]], %[[L2SH7]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63] : vector<32xf32>, vector<32xf32>
+//       CHECK:   %[[L4SH0:.+]] = vector.shuffle %[[L3SH0]], %[[L3SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32>
+//       CHECK:   %[[L4SH1:.+]] = vector.shuffle %[[L3SH2]], %[[L3SH3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127] : vector<64xf32>, vector<64xf32>
+//       CHECK:   %[[L5SH0:.+]] = vector.shuffle %[[L4SH0]], %[[L4SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255] : vector<128xf32>, vector<128xf32>
+//       CHECK:   return %[[L5SH0]] : vector<256xf32>
+
+// -----
+
+// Gathers all 16 elements of four 4-element vectors in an arbitrary,
+// interleaved order, exercising non-contiguous shuffle masks with poison (-1)
+// lanes in the first level of the tree.
+func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
+                                                             %b: vector<4xf32>,
+                                                             %c: vector<4xf32>,
+                                                             %d: vector<4xf32>) -> vector<16xf32> {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  %1:4 = vector.to_elements %b : vector<4xf32>
+  %2:4 = vector.to_elements %c : vector<4xf32>
+  %3:4 = vector.to_elements %d : vector<4xf32>
+  %4 = vector.from_elements %3#3, %0#0, %2#2, %1#1, %3#0, %2#1, %0#3, %1#2, %0#1, %3#2, %1#0, %2#3, %1#3, %0#2, %3#1, %2#0 : vector<16xf32>
+  return %4 : vector<16xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(
+//  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 15, 16, 4, 18, 6, 20, 8, 9, 23, 24, 25, 13, 14, 28] : vector<15xf32>, vector<15xf32>
+//       CHECK:   return %[[L1SH0]] : vector<16xf32>
+
+// -----
+
+// Gathers all 12 elements of three 4-element vectors in an interleaved order.
+// With an odd number of inputs, the third vector is expected to be shuffled
+// with itself in the first level of the tree.
+func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
+                                                             %b: vector<4xf32>,
+                                                             %c: vector<4xf32>) -> vector<12xf32> {
+  %0:4 = vector.to_elements %a : vector<4xf32>
+  %1:4 = vector.to_elements %b : vector<4xf32>
+  %2:4 = vector.to_elements %c : vector<4xf32>
+  %3 = vector.from_elements %0#2, %1#1, %2#0, %0#1, %1#0, %2#2, %0#0, %1#3, %2#3, %0#3, %1#2, %2#1 : vector<12xf32>
+  return %3 : vector<12xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(
+//  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 11, 3, 4, 14, 6, 7, 17, 9, 10, 20] : vector<11xf32>, vector<11xf32>
+//       CHECK:   return %[[L1SH0]] : vector<12xf32>
+
+// -----
+
+// Selects only a subset of the elements of three 5-element vectors, some of
+// them more than once (e.g. %0#1 and %2#2 appear twice), to exercise repeated
+// and unused input lanes.
+func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
+                                                            %b: vector<5xf32>,
+                                                            %c: vector<5xf32>) -> vector<9xf32> {
+  %0:5 = vector.to_elements %a : vector<5xf32>
+  %1:5 = vector.to_elements %b : vector<5xf32>
+  %2:5 = vector.to_elements %c : vector<5xf32>
+  %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2, %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+  return %3 : vector<9xf32>
+}
+
+// TODO: Implement mask compression to reduce the number of intermediate poison values.
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(
+//  CHECK-SAME:     %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 8, 9, 4, 5, 6, 7, 14] : vector<8xf32>, vector<8xf32>
+//       CHECK:   return %[[L1SH0]] : vector<9xf32>
+
+// -----
+
+// Repeats every element of four 2-element vectors four times. The repetition
+// is expected to happen in the first-level shuffles; the second level only
+// concatenates their results.
+func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
+                                                             %b: vector<2xf32>,
+                                                             %c: vector<2xf32>,
+                                                             %d: vector<2xf32>) -> vector<32xf32> {
+  %0:2 = vector.to_elements %a : vector<2xf32>
+  %1:2 = vector.to_elements %b : vector<2xf32>
+  %2:2 = vector.to_elements %c : vector<2xf32>
+  %3:2 = vector.to_elements %d : vector<2xf32>
+  %4 = vector.from_elements %0#0, %0#0, %0#0, %0#0, %0#1, %0#1, %0#1, %0#1,
+                            %1#0, %1#0, %1#0, %1#0, %1#1, %1#1, %1#1, %1#1,
+                            %2#0, %2#0, %2#0, %2#0, %2#1, %2#1, %2#1, %2#1,
+                            %3#0, %3#0, %3#0, %3#0, %3#1, %3#1, %3#1, %3#1 : vector<32xf32>
+  return %4 : vector<32xf32>
+}
+
+// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(
+//  CHECK-SAME:     %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
+//       CHECK:   return %[[L1SH0]] : vector<32xf32>

>From 810c85a3d3aaabf93adb0c40c6aef0d0aa52be4b Mon Sep 17 00:00:00 2001
From: Diego Caballero <dcaballero at nvidia.com>
Date: Tue, 1 Jul 2025 05:05:30 +0000
Subject: [PATCH 2/2] Feedback

---
 ...LowerVectorToFromElementsToShuffleTree.cpp | 161 ++++++++++--------
 ...m-elements-to-shuffle-tree-transforms.mlir | 103 +++++++++--
 2 files changed, 181 insertions(+), 83 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 53728d6dbe2a3..504103529cdcb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -124,7 +124,6 @@ constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
 ///
 /// TODO: Implement mask compression to reduce the number of intermediate poison
 /// values.
-///
 class VectorShuffleTreeBuilder {
 public:
   VectorShuffleTreeBuilder() = delete;
@@ -142,8 +141,8 @@ class VectorShuffleTreeBuilder {
 
 private:
   // IR input information.
-  FromElementsOp fromElementsOp;
-  SmallVector<ToElementsOp> toElementsDefs;
+  FromElementsOp fromElemsOp;
+  SmallVector<ToElementsOp> toElemsDefs;
 
   // Shuffle tree configuration.
   unsigned numLevels;
@@ -162,16 +161,19 @@ class VectorShuffleTreeBuilder {
 
 VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
     FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
-    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
-
-  assert(fromElementsOp && "from_elements op is required");
-  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+    : fromElemsOp(fromElemOp), toElemsDefs(toElemDefs) {
+  assert(fromElemsOp && "from_elements op is required");
+  assert(!toElemsDefs.empty() && "At least one to_elements op is required");
+}
 
-  // Duplicate the last vector if the number of `vector.to_elements` is odd to
-  // simplify the shuffle tree algorithm.
-  if (toElementsDefs.size() % 2 != 0) {
-    toElementsDefs.push_back(toElementsDefs.back());
-  }
+/// Pads `values` with a copy of its last entry (operation, value or interval)
+/// when it holds an odd number of entries. The shuffle tree combines inputs in
+/// pairs, so after padding, the final pair of an odd-sized level shuffles the
+/// last (duplicated) input vector with itself.
+template <typename T>
+static void duplicateLastIfOdd(SmallVectorImpl<T> &values) {
+  if ((values.size() & 1u) != 0)
+    values.push_back(values.back());
 }
 
 // ===--------------------------------------------------------------------===//
@@ -207,20 +209,20 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
   // Map `vector.to_elements` ops to their ordinal position in the
   // `vector.from_elements` operand list. Make sure duplicated
   // `vector.to_elements` ops are mapped to the its first occurrence.
-  DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
-  for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
-    toElementsToInputOrdinal.insert({toElementsOp, idx});
+  DenseMap<ToElementsOp, unsigned> toElemsToInputOrdinal;
+  for (const auto &[idx, toElemsOp] : llvm::enumerate(toElemsDefs))
+    toElemsToInputOrdinal.insert({toElemsOp, idx});
 
   // Compute intervals for each input vector in the shuffle tree. The first
   // level computation is special-cased to keep the implementation simpler.
 
-  SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+  SmallVector<Interval> firstLevelIntervals(toElemsDefs.size(),
                                             {kMaxUnsigned, kMaxUnsigned});
 
   for (const auto &[idx, element] :
-       llvm::enumerate(fromElementsOp.getElements())) {
-    auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
-    unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+       llvm::enumerate(fromElemsOp.getElements())) {
+    auto toElemsOp = cast<ToElementsOp>(element.getDefiningOp());
+    unsigned inputIdx = toElemsToInputOrdinal[toElemsOp];
     Interval &currentInterval = firstLevelIntervals[inputIdx];
 
     // Set lower bound to the first occurrence of the `vector.to_elements`.
@@ -231,19 +233,13 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
     currentInterval.second = idx;
   }
 
-  // If the number of `vector.to_elements` is odd and the last op was
-  // duplicated, the interval for the duplicated op was not computed in the
-  // previous step as all the input occurrences were mapped to the original op.
-  // We copy the interval of the original op to the interval of the duplicated
-  // op manually.
-  if (firstLevelIntervals.back().second == kMaxUnsigned)
-    firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
-
+  duplicateLastIfOdd(toElemsDefs);
+  duplicateLastIfOdd(firstLevelIntervals);
   inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
 
   // Compute intervals for the remaining levels.
   unsigned outputNumElements =
-      cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+      cast<VectorType>(fromElemsOp.getResult().getType()).getNumElements();
   for (unsigned level = 1; level < numLevels; ++level) {
     const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
     SmallVector<Interval> currentLevelIntervals(
@@ -265,6 +261,7 @@ void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
                    outputNumElements - 1);
     }
 
+    duplicateLastIfOdd(currentLevelIntervals);
     inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
   }
 }
@@ -311,9 +308,9 @@ void VectorShuffleTreeBuilder::dump() {
     ++indLv;
     llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
     ++indLv;
-    for (const auto &toElementsOp : toElementsDefs)
-      llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
-    llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+    for (const auto &toElemsOp : toElemsDefs)
+      llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElemsOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElemsOp << "\n\n";
     --indLv;
 
     llvm::dbgs() << llvm::indent(indLv, kIndScale)
@@ -366,9 +363,7 @@ void VectorShuffleTreeBuilder::dump() {
 /// corresponding utility functions.
 LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
   // Initialize shuffle tree information based on its size.
-  assert(toElementsDefs.size() > 1 &&
-         "At least two 'vector.to_elements' ops are required");
-  numLevels = llvm::Log2_64(toElementsDefs.size());
+  numLevels = std::max(1u, llvm::Log2_64_Ceil(toElemsDefs.size()));
   vectorSizePerLevel.resize(numLevels, 0);
   inputIntervalsPerLevel.reserve(numLevels);
 
@@ -402,17 +397,18 @@ LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
 ///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
 ///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
 ///
-/// TODO: Implement mask compression.
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
 static SmallVector<int64_t> computePermutationShuffleMask(
     ToElementsOp toElementOp0, const Interval &interval0,
     ToElementsOp toElementOp1, const Interval &interval1,
-    FromElementsOp fromElementsOp, unsigned outputVectorSize) {
+    FromElementsOp fromElemsOp, unsigned outputVectorSize) {
   SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
   unsigned inputVectorSize =
       toElementOp0.getSource().getType().getNumElements();
 
   for (const auto &[inputIdx, element] :
-       llvm::enumerate(fromElementsOp.getElements())) {
+       llvm::enumerate(fromElemsOp.getElements())) {
     auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
     // Match `vector.from_elements` operands to the two input ops.
     if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
@@ -476,8 +472,8 @@ static SmallVector<int64_t> computePermutationShuffleMask(
 ///   // Level 1, vector length = 9
 ///   PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
 ///
-/// TODO: Implement mask compression.
-///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
 static SmallVector<int64_t> computePropagationShuffleMask(
     ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
     const Interval &rhsInterval, unsigned outputVectorSize) {
@@ -487,6 +483,7 @@ static SmallVector<int64_t> computePropagationShuffleMask(
   assert(inputVectorSize == rhsShuffleMask.size() &&
          "Expected both shuffle masks to have the same size");
 
+  bool hasSameInput = lhsShuffleOp == rhsShuffleOp;
   unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
   SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
 
@@ -497,6 +494,9 @@ static SmallVector<int64_t> computePropagationShuffleMask(
     if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
       mask[i] = i;
 
+    if (hasSameInput)
+      continue;
+
     unsigned rhsIdx = i + lhsRhsOffset;
     if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
       assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
@@ -565,15 +565,19 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
 
   // Initialize work list with the `vector.to_elements` sources.
   SmallVector<Value> levelInputs;
-  llvm::transform(
-      toElementsDefs, std::back_inserter(levelInputs),
-      [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); });
+  llvm::transform(toElemsDefs, std::back_inserter(levelInputs),
+                  [](ToElementsOp toElemsOp) { return toElemsOp.getSource(); });
+  // TODO: Check that every pair of inputs has the same vector size. Otherwise,
+  // promote the narrower one to the wider one.
 
   // Build shuffle tree by combining pairs of vectors.
-  Location loc = fromElementsOp.getLoc();
+  Location loc = fromElemsOp.getLoc();
   unsigned currentLevel = 0;
   for (const auto &[levelVectorSize, inputIntervals] :
        llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) {
+
+    duplicateLastIfOdd(levelInputs);
+
     LLVM_DEBUG(llvm::dbgs()
                << llvm::indent(1, kIndScale) << "* Processing level "
                << currentLevel << " (vector size: " << levelVectorSize
@@ -593,8 +597,8 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
       SmallVector<int64_t> shuffleMask;
       if (currentLevel == 0) {
         shuffleMask = computePermutationShuffleMask(
-            toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval,
-            fromElementsOp, levelVectorSize);
+            toElemsDefs[i], lhsInterval, toElemsDefs[i + 1], rhsInterval,
+            fromElemsOp, levelVectorSize);
       } else {
         auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp());
         auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp());
@@ -621,17 +625,17 @@ Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
 /// returned in order of appearance in the `vector.from_elements`'s operand
 /// list.
 static LogicalResult
-getToElementsDefiningOps(FromElementsOp fromElementsOp,
-                         SmallVectorImpl<ToElementsOp> &toElementsDefs) {
-  SetVector<ToElementsOp> toElementsDefsSet;
-  for (Value element : fromElementsOp.getElements()) {
-    auto toElementsOp = element.getDefiningOp<ToElementsOp>();
-    if (!toElementsOp)
+getToElementsDefiningOps(FromElementsOp fromElemsOp,
+                         SmallVectorImpl<ToElementsOp> &toElemsDefs) {
+  SetVector<ToElementsOp> toElemsDefsSet;
+  for (Value element : fromElemsOp.getElements()) {
+    auto toElemsOp = element.getDefiningOp<ToElementsOp>();
+    if (!toElemsOp)
       return failure();
-    toElementsDefsSet.insert(toElementsOp);
+    toElemsDefsSet.insert(toElemsOp);
   }
 
-  toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end());
+  toElemsDefs.assign(toElemsDefsSet.begin(), toElemsDefsSet.end());
   return success();
 }
 
@@ -642,30 +646,53 @@ struct ToFromElementsToShuffleTreeRewrite final
 
   using OpRewritePattern::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
+  LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
                                 PatternRewriter &rewriter) const override {
-    VectorType resultType = fromElementsOp.getType();
-    if (resultType.getRank() != 1 || resultType.isScalable())
-      return failure();
+    VectorType resultType = fromElemsOp.getType();
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          fromElemsOp, "Multi-dimensional vectors are not supported yet");
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          fromElemsOp,
+          "'vector.from_elements' does not support scalable vectors");
+
+    SmallVector<ToElementsOp> toElemsDefs;
+    if (failed(getToElementsDefiningOps(fromElemsOp, toElemsDefs)))
+      return rewriter.notifyMatchFailure(fromElemsOp, "unsupported sources");
+
+    int64_t numElements =
+        toElemsDefs.front().getSource().getType().getNumElements();
+    for (ToElementsOp toElemsOp : toElemsDefs) {
+      if (toElemsOp.getSource().getType().getNumElements() != numElements)
+        return rewriter.notifyMatchFailure(
+            fromElemsOp, "unsupported sources with different vector sizes");
+    }
 
-    SmallVector<ToElementsOp> toElementsDefs;
-    if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs)))
-      return failure();
+    // Reject 0-D sources. VectorType is always ranked, so check the rank value.
+    if (llvm::any_of(toElemsDefs, [](ToElementsOp toElemsOp) {
+          return toElemsOp.getSource().getType().getRank() == 0;
+        }))
+      return rewriter.notifyMatchFailure(fromElemsOp,
+                                         "0-D vectors are not supported");
 
     // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
     // `vector.from_elements` forwarding cases that do not require shuffling.
-    if (toElementsDefs.size() == 1) {
-      ToElementsOp toElementsOp0 = toElementsDefs.front();
-      if (llvm::equal(fromElementsOp.getElements(), toElementsOp0.getResults()))
-        return failure();
+    if (toElemsDefs.size() == 1) {
+      ToElementsOp toElemsOp0 = toElemsDefs.front();
+      if (llvm::equal(fromElemsOp.getElements(), toElemsOp0.getResults())) {
+        return rewriter.notifyMatchFailure(
+            fromElemsOp, "trivial forwarding case does not require shuffling");
+      }
     }
 
-    VectorShuffleTreeBuilder shuffleTreeBuilder(fromElementsOp, toElementsDefs);
+    VectorShuffleTreeBuilder shuffleTreeBuilder(fromElemsOp, toElemsDefs);
     if (failed(shuffleTreeBuilder.computeShuffleTree()))
-      return failure();
+      return rewriter.notifyMatchFailure(fromElemsOp,
+                                         "failed to compute shuffle tree");
 
     Value finalShuffle = shuffleTreeBuilder.generateShuffleTree(rewriter);
-    rewriter.replaceOp(fromElementsOp, finalShuffle);
+    rewriter.replaceOp(fromElemsOp, finalShuffle);
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
index 3dc579be12f0f..a8d3d5278e893 100644
--- a/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir
@@ -4,13 +4,27 @@
 // where L# refers to the level of the tree the shuffle belongs to, and SH# refers to
 // the shuffle index within that level.
 
-func.func @to_from_elements_single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
+func.func @trivial_forwarding(%a: vector<8xf32>) -> vector<8xf32> {
+  %0:8 = vector.to_elements %a : vector<8xf32>
+  %1 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7 : vector<8xf32>
+  return %1 : vector<8xf32>
+}
+
+// No shuffle tree needed for trivial forwarding case.
+
+// CHECK-LABEL: func @trivial_forwarding(
+//  CHECK-SAME:     %[[A:.*]]: vector<8xf32>
+//       CHECK:   return %[[A]] : vector<8xf32>
+
+// -----
+
+func.func @single_input_shuffle(%a: vector<8xf32>) -> vector<8xf32> {
   %0:8 = vector.to_elements %a : vector<8xf32>
   %1 = vector.from_elements %0#7, %0#0, %0#6, %0#1, %0#5, %0#2, %0#4, %0#3 : vector<8xf32>
   return %1 : vector<8xf32>
 }
 
-// CHECK-LABEL: func @to_from_elements_single_input_shuffle(
+// CHECK-LABEL: func @single_input_shuffle(
 //  CHECK-SAME:     %[[A:.*]]: vector<8xf32>
       // CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[A]] [7, 0, 6, 1, 5, 2, 4, 3] : vector<8xf32>, vector<8xf32>
       // CHECK:   return %[[L0SH0]]
@@ -32,7 +46,7 @@ func.func @from_elements_to_elements_single_shuffle(%a: vector<8xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
+func.func @shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
                                                           %b: vector<8xf32>,
                                                           %c: vector<8xf32>,
                                                           %d: vector<8xf32>) -> vector<32xf32> {
@@ -47,7 +61,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
   return %4 : vector<32xf32>
 }
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_4x8_to_32(
+// CHECK-LABEL: func @shuffle_tree_concat_4x8_to_32(
 //  CHECK-SAME:     %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>, %[[C:.*]]: vector<8xf32>, %[[D:.*]]: vector<8xf32>
 //       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
 //       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<8xf32>, vector<8xf32>
@@ -56,7 +70,7 @@ func.func @to_from_elements_shuffle_tree_concat_4x8_to_32(%a: vector<8xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
+func.func @shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
                                                           %b: vector<4xf32>,
                                                           %c: vector<4xf32>) -> vector<12xf32> {
   %0:4 = vector.to_elements %a : vector<4xf32>
@@ -66,7 +80,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
   return %3 : vector<12xf32>
 }
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_concat_3x4_to_12(
+// CHECK-LABEL: func @shuffle_tree_concat_3x4_to_12(
 //  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 //       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
 //       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
@@ -75,7 +89,7 @@ func.func @to_from_elements_shuffle_tree_concat_3x4_to_12(%a: vector<4xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+func.func @shuffle_tree_concat_64x4_256(
   %a: vector<4xf32>, %b: vector<4xf32>, %c: vector<4xf32>, %d: vector<4xf32>,
   %e: vector<4xf32>, %f: vector<4xf32>, %g: vector<4xf32>, %h: vector<4xf32>,
   %i: vector<4xf32>, %j: vector<4xf32>, %k: vector<4xf32>, %l: vector<4xf32>,
@@ -172,7 +186,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256(
   return %64 : vector<256xf32>
 }
 
-// CHECK-LABEL: func.func @to_from_elements_shuffle_tree_concat_64x4_256(
+// CHECK-LABEL: func.func @shuffle_tree_concat_64x4_256(
 //  CHECK-SAME:     %[[A:.+]]: vector<4xf32>, %[[B:.+]]: vector<4xf32>, %[[C:.+]]: vector<4xf32>, %[[D:.+]]: vector<4xf32>, %[[E:.+]]: vector<4xf32>, %[[F:.+]]: vector<4xf32>, %[[G:.+]]: vector<4xf32>, %[[H:.+]]: vector<4xf32>, %[[I:.+]]: vector<4xf32>, %[[J:.+]]: vector<4xf32>, %[[K:.+]]: vector<4xf32>, %[[L:.+]]: vector<4xf32>, %[[M:.+]]: vector<4xf32>, %[[N:.+]]: vector<4xf32>, %[[O:.+]]: vector<4xf32>, %[[P:.+]]: vector<4xf32>, %[[Q:.+]]: vector<4xf32>, %[[R:.+]]: vector<4xf32>, %[[S:.+]]: vector<4xf32>, %[[T:.+]]: vector<4xf32>, %[[U:.+]]: vector<4xf32>, %[[V:.+]]: vector<4xf32>, %[[W:.+]]: vector<4xf32>, %[[X:.+]]: vector<4xf32>, %[[Y:.+]]: vector<4xf32>, %[[Z:.+]]: vector<4xf32>, %[[AA:.+]]: vector<4xf32>, %[[AB:.+]]: vector<4xf32>, %[[AC:.+]]: vector<4xf32>, %[[AD:.+]]: vector<4xf32>, %[[AE:.+]]: vector<4xf32>, %[[AF:.+]]: vector<4xf32>, %[[AG:.+]]: vector<4xf32>, %[[AH:.+]]: vector<4xf32>, %[[AI:.+]]: vector<4xf32>, %[[AJ:.+]]: vector<4xf32>, %[[AK:.+]]: vector<4xf32>, %[[AL:.+]]: vector<4xf32>, %[[AM:.+]]: vector<4xf32>, %[[AN:.+]]: vector<4xf32>, %[[AO:.+]]: vector<4xf32>, %[[AP:.+]]: vector<4xf32>, %[[AQ:.+]]: vector<4xf32>, %[[AR:.+]]: vector<4xf32>, %[[AS:.+]]: vector<4xf32>, %[[AT:.+]]: vector<4xf32>, %[[AU:.+]]: vector<4xf32>, %[[AV:.+]]: vector<4xf32>, %[[AW:.+]]: vector<4xf32>, %[[AX:.+]]: vector<4xf32>, %[[AY:.+]]: vector<4xf32>, %[[AZ:.+]]: vector<4xf32>, %[[BA:.+]]: vector<4xf32>, %[[BB:.+]]: vector<4xf32>, %[[BC:.+]]: vector<4xf32>, %[[BD:.+]]: vector<4xf32>, %[[BE:.+]]: vector<4xf32>, %[[BF:.+]]: vector<4xf32>, %[[BG:.+]]: vector<4xf32>, %[[BH:.+]]: vector<4xf32>, %[[BI:.+]]: vector<4xf32>, %[[BJ:.+]]: vector<4xf32>, %[[BK:.+]]: vector<4xf32>, %[[BL:.+]]: vector<4xf32>)
 //       CHECK:   %[[L0SH0:.+]] = vector.shuffle %[[A]], %[[B]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
 //       CHECK:   %[[L0SH1:.+]] = vector.shuffle %[[C]], %[[D]] [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
@@ -241,7 +255,7 @@ func.func @to_from_elements_shuffle_tree_concat_64x4_256(
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
+func.func @shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
                                                              %b: vector<4xf32>,
                                                              %c: vector<4xf32>,
                                                              %d: vector<4xf32>) -> vector<16xf32> {
@@ -255,7 +269,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
 
 // TODO: Implement mask compression to reduce the number of intermediate poison values.
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_4x4_to_16(
 //  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>, %[[D:.*]]: vector<4xf32>
 //       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[D]], %[[A]] [3, 4, -1, -1, 0, -1, 7, -1, 5, 2, -1, -1, -1, 6, 1] : vector<4xf32>, vector<4xf32>
 //       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[B]] [2, 5, -1, 1, -1, 6, -1, -1, 4, 3, 7, -1, -1, 0, -1] : vector<4xf32>, vector<4xf32>
@@ -264,7 +278,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_4x4_to_16(%a: vector<4xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
+func.func @shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
                                                              %b: vector<4xf32>,
                                                              %c: vector<4xf32>) -> vector<12xf32> {
   %0:4 = vector.to_elements %a : vector<4xf32>
@@ -276,7 +290,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
 
 // TODO: Implement mask compression to reduce the number of intermediate poison values.
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_3x4_to_12(
 //  CHECK-SAME:     %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 //       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [2, 5, -1, 1, 4, -1, 0, 7, -1, 3, 6] : vector<4xf32>, vector<4xf32>
 //       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[C]] [0, -1, -1, 2, -1, -1, 3, -1, -1, 1, -1] : vector<4xf32>, vector<4xf32>
@@ -285,7 +299,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x4_to_12(%a: vector<4xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
+func.func @shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
                                                             %b: vector<5xf32>,
                                                             %c: vector<5xf32>) -> vector<9xf32> {
   %0:5 = vector.to_elements %a : vector<5xf32>
@@ -297,7 +311,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
 
 // TODO: Implement mask compression to reduce the number of intermediate poison values.
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(
+// CHECK-LABEL: func @shuffle_tree_arbitrary_3x5_to_9(
 //  CHECK-SAME:     %[[A:.*]]: vector<5xf32>, %[[B:.*]]: vector<5xf32>, %[[C:.*]]: vector<5xf32>
 //       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[C]], %[[B]] [2, 6, -1, -1, 7, 2, 0, 6] : vector<5xf32>, vector<5xf32>
 //       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[A]], %[[A]] [1, 1, -1, -1, -1, -1, 4, -1] : vector<5xf32>, vector<5xf32>
@@ -306,7 +320,7 @@ func.func @to_from_elements_shuffle_tree_arbitrary_3x5_to_9(%a: vector<5xf32>,
 
 // -----
 
-func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
+func.func @shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
                                                              %b: vector<2xf32>,
                                                              %c: vector<2xf32>,
                                                              %d: vector<2xf32>) -> vector<32xf32> {
@@ -321,9 +335,66 @@ func.func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(%a: vector<2xf32>,
   return %4 : vector<32xf32>
 }
 
-// CHECK-LABEL: func @to_from_elements_shuffle_tree_broadcast_4x2_to_32(
+// CHECK-LABEL: func @shuffle_tree_broadcast_4x2_to_32(
 //  CHECK-SAME:     %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>
       // CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
       // CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3] : vector<2xf32>, vector<2xf32>
       // CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31] : vector<16xf32>, vector<16xf32>
       // CHECK:   return %[[L1SH0]] : vector<32xf32>
+
+// -----
+
+
+func.func @shuffle_tree_arbitrary_mixed_sizes(
+     %a : vector<2xf32>,
+     %b : vector<1xf32>,
+     %c : vector<3xf32>,
+     %d : vector<1xf32>,
+     %e : vector<5xf32>) -> vector<6xf32> {
+  %0:2 = vector.to_elements %a : vector<2xf32>
+  %1 = vector.to_elements %b : vector<1xf32>
+  %2:3 = vector.to_elements %c : vector<3xf32>
+  %3 = vector.to_elements %d : vector<1xf32>
+  %4:5 = vector.to_elements %e : vector<5xf32>
+  %5 = vector.from_elements %0#0, %2#0, %3, %4#0, %1, %4#3 : vector<6xf32>
+  return %5 : vector<6xf32>
+}
+
+// TODO: Support mixed vector sizes; until then the IR is left unchanged.
+
+//   CHECK-LABEL: func @shuffle_tree_arbitrary_mixed_sizes(
+// CHECK-COUNT-5:   vector.to_elements
+//        CHECK:    vector.from_elements
+
+// -----
+
+func.func @shuffle_tree_odd_intermediate_vectors(
+     %a : vector<2xf32>,
+     %b : vector<2xf32>,
+     %c : vector<2xf32>,
+     %d : vector<2xf32>,
+     %e : vector<2xf32>,
+     %f : vector<2xf32>) -> vector<6xf32> {
+  %0:2 = vector.to_elements %a : vector<2xf32>
+  %1:2 = vector.to_elements %b : vector<2xf32>
+  %2:2 = vector.to_elements %c : vector<2xf32>
+  %3:2 = vector.to_elements %d : vector<2xf32>
+  %4:2 = vector.to_elements %e : vector<2xf32>
+  %5:2 = vector.to_elements %f : vector<2xf32>
+  %6 = vector.from_elements %0#0, %1#1, %2#0, %3#1, %4#0, %5#1 : vector<6xf32>
+  return %6 : vector<6xf32>
+}
+
+// CHECK-LABEL: func @shuffle_tree_odd_intermediate_vectors(
+//  CHECK-SAME:     %[[A:.*]]: vector<2xf32>, %[[B:.*]]: vector<2xf32>, %[[C:.*]]: vector<2xf32>, %[[D:.*]]: vector<2xf32>, %[[E:.*]]: vector<2xf32>, %[[F:.*]]: vector<2xf32>
+//       CHECK:   %[[L0SH0:.*]] = vector.shuffle %[[A]], %[[B]] [0, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L0SH1:.*]] = vector.shuffle %[[C]], %[[D]] [0, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L0SH2:.*]] = vector.shuffle %[[E]], %[[F]] [0, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L1SH0:.*]] = vector.shuffle %[[L0SH0]], %[[L0SH1]] [0, 1, 2, 3] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L1SH1:.*]] = vector.shuffle %[[L0SH2]], %[[L0SH2]] [0, 1, -1, -1] : vector<2xf32>, vector<2xf32>
+//       CHECK:   %[[L2SH0:.*]] = vector.shuffle %[[L1SH0]], %[[L1SH1]] [0, 1, 2, 3, 4, 5] : vector<4xf32>, vector<4xf32>
+//       CHECK:   return %[[L2SH0]] : vector<6xf32>
+
+
+
+



More information about the Mlir-commits mailing list