[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 25 10:03:45 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

<details>
<summary>Changes</summary>

This PR adds a new transformation that turns sequences of `vector.to_elements` and `vector.from_elements` into a binary tree of `vector.shuffle` operations. 
(Related RFC: https://discourse.llvm.org/t/rfc-adding-vector-to-elements-op-to-the-vector-dialect/86779).

Example:

```
  %0:4 = vector.to_elements %a : vector<4xf32>
  %1:4 = vector.to_elements %b : vector<4xf32>
  %2:4 = vector.to_elements %c : vector<4xf32>
  %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3,
                            %1#0, %1#1, %1#2, %1#3,
                            %2#0, %2#1, %2#2, %2#3 : vector<12xf32>

==>

  %0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7] : vector<4xf32>, vector<4xf32>
  %1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1] : vector<4xf32>, vector<4xf32>
  %2 = vector.shuffle %0, %1 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] : vector<8xf32>, vector<8xf32>
```
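The following C++ sketch models `vector.shuffle` on 1-D vectors so the mask semantics in the example are concrete. A mask entry of `-1` produces a poison lane, which the sketch represents as `std::nullopt`. `shuffle` and `concatViaShuffleTree` are hypothetical helpers for illustration only, not MLIR APIs.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// A lane holds either a concrete value or poison (std::nullopt).
using Lane = std::optional<float>;
using Vec = std::vector<Lane>;

// Models `vector.shuffle %v1, %v2 [mask]` on 1-D vectors: an index below
// v1.size() selects from v1, a larger index selects from v2 (offset by
// v1.size()), and -1 produces a poison lane.
Vec shuffle(const Vec &v1, const Vec &v2, const std::vector<int64_t> &mask) {
  Vec result;
  result.reserve(mask.size());
  for (int64_t idx : mask) {
    if (idx == -1) {
      result.push_back(std::nullopt);
      continue;
    }
    assert(idx >= 0 && "mask entries must be -1 or non-negative");
    auto i = static_cast<std::size_t>(idx);
    assert(i < v1.size() + v2.size() && "mask index out of range");
    result.push_back(i < v1.size() ? v1[i] : v2[i - v1.size()]);
  }
  return result;
}

// Replays the three shuffles from the example above on concrete inputs.
Vec concatViaShuffleTree(const Vec &a, const Vec &b, const Vec &c) {
  Vec s0 = shuffle(a, b, {0, 1, 2, 3, 4, 5, 6, 7});
  Vec s1 = shuffle(c, c, {0, 1, 2, 3, -1, -1, -1, -1});
  return shuffle(s0, s1, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
}
```

If you pass `a = {0, 1, 2, 3}`, `b = {4, 5, 6, 7}`, and `c = {8, 9, 10, 11}` through `concatViaShuffleTree`, you get lanes 0 to 11 with no poison. This is the same vector the original `vector.from_elements` builds. The poison lanes of the second shuffle are never selected by the final mask.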

The algorithm leverages the structured extraction/insertion information of `vector.to_elements` and `vector.from_elements` operations and builds a set of intervals to determine the vector length that should be used at each level of the tree to combine the level inputs in pairs.
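The interval analysis can be modeled without MLIR. Below is a minimal C++ sketch. It assumes each `vector.from_elements` operand has been reduced to the ordinal of the `vector.to_elements` op that feeds it, numbered by first appearance. `computeTreeConfig` and `TreeConfig` are hypothetical names. The patch's constructor duplicates the last input when the input count is odd. This sketch applies the same padding at every level, which is an assumption on my part; it lets the sketch handle input counts that are not powers of two.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Closed interval [first, second] of final output positions.
using Interval = std::pair<unsigned, unsigned>;

struct TreeConfig {
  std::vector<std::vector<Interval>> inputIntervalsPerLevel;
  std::vector<unsigned> vectorSizePerLevel;
};

// Hypothetical model of the analysis: operandSources[p] is the ordinal (by
// first appearance) of the `vector.to_elements` op that feeds output position
// p of the `vector.from_elements` op. Every ordinal in [0, numSources) is
// assumed to appear at least once.
TreeConfig computeTreeConfig(const std::vector<unsigned> &operandSources,
                             unsigned numSources) {
  assert(numSources > 0 && "at least one source vector is required");
  // Level 0: each source covers [first use, last use] in the final output.
  std::vector<Interval> level(numSources, {~0u, 0u});
  for (unsigned p = 0; p < operandSources.size(); ++p) {
    Interval &iv = level[operandSources[p]];
    iv.first = std::min(iv.first, p);
    iv.second = std::max(iv.second, p);
  }

  TreeConfig cfg;
  while (true) {
    // Pair inputs in order; duplicate the last one if the count is odd
    // (assumption: applied at every level, not only the first).
    if (level.size() % 2 != 0)
      level.push_back(level.back());
    std::vector<Interval> next;
    unsigned levelSize = 1;
    for (std::size_t i = 0; i < level.size(); i += 2) {
      // A shuffled pair covers the union of its two input intervals.
      Interval merged{std::min(level[i].first, level[i + 1].first),
                      std::max(level[i].second, level[i + 1].second)};
      levelSize = std::max(levelSize, merged.second - merged.first + 1);
      next.push_back(merged);
    }
    cfg.inputIntervalsPerLevel.push_back(level);
    cfg.vectorSizePerLevel.push_back(levelSize);
    if (next.size() == 1)
      break;
    level = std::move(next);
  }
  return cfg;
}
```

For Example #2 in the patch, `computeTreeConfig({0, 1, 2, 2, 1, 0, 0, 1, 2}, 3)` gives:

- level-0 intervals `[0,6], [1,7], [2,8], [2,8]`
- level-1 intervals `[0,7], [2,8]`
- vector sizes `{8, 9}`

For three consecutive `vector<4xf32>` sources (Example #1), the sizes are `{8, 12}`.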

Some improvements are left for future work, such as shuffle mask compression to avoid unnecessarily large intermediate vectors padded with poison values. I decided to keep things "simpler" and spend more time documenting the different steps of the algorithm so that people can follow along.

---

Patch is 62.71 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145740.diff


6 Files Affected:

- (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+7) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.h (+1) 
- (modified) mlir/include/mlir/Dialect/Vector/Transforms/Passes.td (+5) 
- (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1) 
- (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp (+692) 
- (added) mlir/test/Dialect/Vector/vector-tofrom-elements-to-shuffle-tree-transforms.mlir (+329) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 14cff4ff893b5..6761cd65c5009 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -297,6 +297,13 @@ void populateVectorBitCastLoweringPatterns(RewritePatternSet &patterns,
 /// n > 1.
 void populateVectorRankReducingFMAPattern(RewritePatternSet &patterns);
 
+/// Populate patterns to rewrite sequences of `vector.to_elements` +
+/// `vector.from_elements` operations into a tree of `vector.shuffle`
+/// operations.
+void populateVectorToFromElementsToShuffleTreePatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
index 5667f4fa95ace..959c2fbf31f1a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 #define MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES_H_
 
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
 
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
index 7436998749791..9431a4d8e240f 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/Passes.td
@@ -34,4 +34,9 @@ def LowerVectorMultiReduction : Pass<"lower-vector-multi-reduction", "func::Func
   ];
 }
 
+def LowerVectorToFromElementsToShuffleTree
+  : Pass<"lower-vector-to-from-elements-to-shuffle-tree", "func::FuncOp"> {
+  let summary = "Lower `vector.to_elements` and `vector.from_elements` to a tree of `vector.shuffle` operations";
+}
+
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 8ca5cb6c6dfab..9e287fc109990 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
   SubsetOpInterfaceImpl.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
new file mode 100644
index 0000000000000..53728d6dbe2a3
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -0,0 +1,692 @@
+//===- LowerVectorToFromElementsToShuffleTree.cpp - Shuffle tree ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each input vector to each level is used only once.
+///   3. The number of levels in the tree is:
+///        ceil(log2(# `vector.to_elements` ops)).
+///   4. Vectors at each level of the tree have the same vector length.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                               : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///     : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///     : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * The shuffle tree has two levels:
+///         - Level 1 = (%shuffle0, %shuffle1)
+///         - Level 2 = (%result)
+///     * `%a` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///     * `%c` is shuffled with itself because the number of
+///       `vector.to_elements` operations is odd.
+///     * The vector lengths for the first and second levels are 8 and 12,
+///       respectively.
+///     * `%shuffle1` uses poison values to match the vector length of its
+///       tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///     : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///     : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * `%c` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///     * `%a` is shuffled with itself because the number of
+///       `vector.to_elements` operations is odd.
+///     * The vector lengths for the first and second levels are 8 and 9,
+///       respectively.
+///     * `%shuffle0` and `%shuffle1` use poison values to mark unused vector
+///       positions; `%shuffle1` also pads to its level's vector length (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElementsOp;
+  SmallVector<ToElementsOp> toElementsDefs;
+
+  // Shuffle tree configuration.
+  unsigned numLevels;
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions in the final output that each vector input
+  /// in the tree is contributing to.
+  SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeInputVectorIntervals();
+  void computeOutputVectorSizePerLevel();
+
+  /// Dump the shuffle tree configuration.
+  void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+  assert(fromElementsOp && "from_elements op is required");
+  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+  // Duplicate the last vector if the number of `vector.to_elements` is odd to
+  // simplify the shuffle tree algorithm.
+  if (toElementsDefs.size() % 2 != 0) {
+    toElementsDefs.push_back(toElementsDefs.back());
+  }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0; the last one is duplicated to make the
+/// number of inputs even), so we compute the interval for each input vector:
+///
+///    * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+///    * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+///    * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+///    * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0, so
+/// we compute the intervals for each input vector to level 1 as:
+///    * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+///    * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+  // Map `vector.to_elements` ops to their ordinal position in the
+  // `vector.from_elements` operand list. Make sure duplicated
+  // `vector.to_elements` ops are mapped to their first occurrence.
+  DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+  for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+    toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+  // Compute intervals for each input vector in the shuffle tree. The first
+  // level computation is special-cased to keep the implementation simpler.
+
+  SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+                                            {kMaxUnsigned, kMaxUnsigned});
+
+  for (const auto &[idx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+    unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+    Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+    // Set lower bound to the first occurrence of the `vector.to_elements`.
+    if (currentInterval.first == kMaxUnsigned)
+      currentInterval.first = idx;
+
+    // Set upper bound to the last occurrence of the `vector.to_elements`.
+    currentInterval.second = idx;
+  }
+
+  // If the number of `vector.to_elements` is odd and the last op was
+  // duplicated, the interval for the duplicated op was not computed in the
+  // previous step as all the input occurrences were mapped to the original op.
+  // We copy the interval of the original op to the interval of the duplicated
+  // op manually.
+  if (firstLevelIntervals.back().second == kMaxUnsigned)
+    firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+  inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+  // Compute intervals for the remaining levels.
+  unsigned outputNumElements =
+      cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+  for (unsigned level = 1; level < numLevels; ++level) {
+    const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+    SmallVector<Interval> currentLevelIntervals(
+        llvm::divideCeil(prevLevelIntervals.size(), 2),
+        {kMaxUnsigned, kMaxUnsigned});
+
+    for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+         ++inputIdx) {
+      auto &interval = currentLevelIntervals[inputIdx];
+      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+      // The interval of a vector at the current level is the union of the
+      // intervals of the two input vectors from the previous level being
+      // shuffled at this level.
+      interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+      interval.second =
+          std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+                   outputNumElements - 1);
+    }
+
+    inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+  }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   Intervals:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///
+///   Vector sizes:
+///     * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+///                    size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+///     * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+  // Compute vector size for each level.
+  for (unsigned level = 0; level < numLevels; ++level) {
+    const auto &currentLevelIntervals = inputIntervalsPerLevel[level];
+    unsigned currentVectorSize = 1;
+    for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+      const auto &lhsInterval = currentLevelIntervals[i];
+      const auto &rhsInterval = currentLevelIntervals[i + 1];
+      unsigned combinedIntervalSize =
+          std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+          1;
+      currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+    }
+    vectorSizePerLevel[level] = currentVectorSize;
+  }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+  LLVM_DEBUG({
+    unsigned indLv = 0;
+
+    llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+    ++indLv;
+    for (const auto &toElementsOp : toElementsDefs)
+      llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+    --indLv;
+
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Total levels: " << numLevels << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Vector sizes per level: [";
+    llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Input intervals per level:\n";
+    ++indLv;
+    for (const auto &[level, intervals] :
+         llvm::enumerate(inputIntervalsPerLevel)) {
+      llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+                   << ": ";
+      llvm::interleaveComma(intervals, llvm::dbgs(),
+                            [](const Interval &interval) {
+                              llvm::dbgs() << "[" << interval.first << ","
+                                           << interval.second << "]";
+                            });
+      llvm::dbgs() << "\n";
+    }
+  });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   build a tree that looks like:
+///
+///          %2    %1                   %0    %0
+///            \  /                       \  /
+///  %2_1 = vector.shuffle     %0_0 = vector.shuffle
+///              \                    /
+///             %2_1_0_0 = vector.shuffle
+///
+/// The configuration consists of the intervals of the input vectors at each
+/// level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1 and %0_0) and the
+/// output vector size for each level. For further details on how intervals
+/// and output vector sizes are computed, see the corresponding utility
+/// functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+  // Initialize shuffle tree information based on its size.
+  assert(toElementsDefs.size() > 1 &&
+         "At least two 'vector.to_elements' ops are required");
+  numLevels = llvm::Log2_64(toElementsDefs.size());
+  vectorSizePerLevel.resize(numLevels, 0);
+  inputIntervalsPerLevel.reserve(numLevels);
+
+  computeInputVectorIntervals();
+  computeOutputVectorSizePerLevel();
+  dump();
+
+  return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   =>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2,...
[truncated]

``````````
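The permutation-mask step (`PermutationShuffleMask` in the patch comments) is cut off by the truncation above. Below is a hedged C++ sketch of one way to derive such masks from the intervals. It is not the patch's code; `buildShuffleMask`, `sourceLane`, and `subtreeLane` are hypothetical names. The sketch rests on three assumptions:

- Each shuffle's result starts at the lower bound of its interval.
- A final position's value comes from either the LHS input, where its lane is used as-is, or the RHS input, where its lane is offset by the LHS length.
- Every other lane is poison.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// For a final output position p, returns the lane of one shuffle input that
// holds p's value, or -1 if that input does not provide position p.
using LaneOf = std::function<int64_t(unsigned)>;

// Hypothetical helper: mask of one `vector.shuffle` node whose result covers
// final positions [start, start + size). LHS lanes are used as-is, RHS lanes
// are offset by lhsLen, and everything else becomes poison (-1).
std::vector<int64_t> buildShuffleMask(unsigned start, unsigned size,
                                      unsigned numOutputs, int64_t lhsLen,
                                      const LaneOf &lhsLane,
                                      const LaneOf &rhsLane) {
  std::vector<int64_t> mask(size, -1);
  // Lanes past the end of the final result stay poison.
  for (unsigned j = 0; j < size && start + j < numOutputs; ++j) {
    unsigned p = start + j;
    if (int64_t l = lhsLane(p); l >= 0)
      mask[j] = l;
    else if (int64_t r = rhsLane(p); r >= 0)
      mask[j] = lhsLen + r;
  }
  return mask;
}

// Input is an original source vector `src`: position p is provided iff it was
// extracted from `src`, at element index indices[p].
LaneOf sourceLane(std::vector<unsigned> sources, std::vector<int64_t> indices,
                  unsigned src) {
  return [sources = std::move(sources), indices = std::move(indices),
          src](unsigned p) -> int64_t {
    return sources[p] == src ? indices[p] : -1;
  };
}

// Input is an earlier shuffle whose result starts at final position `start`
// and gathers the source vectors listed in `srcs`.
LaneOf subtreeLane(std::vector<unsigned> sources, std::vector<unsigned> srcs,
                   unsigned start) {
  return [sources = std::move(sources), srcs = std::move(srcs),
          start](unsigned p) -> int64_t {
    for (unsigned s : srcs)
      if (sources[p] == s)
        return static_cast<int64_t>(p) - start;
    return -1;
  };
}
```

For Example #2, number `%c`, `%b`, `%a` as 0, 1, 2. Set `src = {0, 1, 2, 2, 1, 0, 0, 1, 2}` and `idx = {2, 1, 1, 1, 2, 2, 0, 1, 4}`. With these inputs:

- `buildShuffleMask(0, 8, 9, 5, sourceLane(src, idx, 0), sourceLane(src, idx, 1))` yields `[2, 6, -1, -1, 7, 2, 0, 6]`.
- The `%a` pair starting at position 2 yields `[1, 1, -1, -1, -1, -1, 4, -1]`.
- Combining the two subtrees with `subtreeLane(src, {0, 1}, 0)` and `subtreeLane(src, {2}, 2)` over 9 lanes yields `[0, 1, 8, 9, 4, 5, 6, 7, 14]`.

These match the masks documented in the patch comments.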

</details>


https://github.com/llvm/llvm-project/pull/145740

