[Mlir-commits] [mlir] [mlir][Vector] Add `vector.shuffle` tree transformation (PR #145740)

James Newling llvmlistbot at llvm.org
Mon Jul 7 09:16:41 PDT 2025


================
@@ -0,0 +1,692 @@
+//===- VectorShuffleTreeBuilder.cpp ----- Vector shuffle tree builder -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements pattern rewrites to lower sequences of
+// `vector.to_elements` and `vector.from_elements` operations into a tree of
+// `vector.shuffle` operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace vector {
+
+#define GEN_PASS_DEF_LOWERVECTORTOFROMELEMENTSTOSHUFFLETREE
+#include "mlir/Dialect/Vector/Transforms/Passes.h.inc"
+
+} // namespace vector
+} // namespace mlir
+
+#define DEBUG_TYPE "lower-vector-to-from-elements-to-shuffle-tree"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+
+// Indentation unit for debug output formatting.
+constexpr unsigned kIndScale = 2;
+
+/// Represents a closed interval of elements (e.g., [0, 7] = 8 elements).
+using Interval = std::pair<unsigned, unsigned>;
+// Sentinel value for uninitialized intervals.
+constexpr unsigned kMaxUnsigned = std::numeric_limits<unsigned>::max();
+
+/// The VectorShuffleTreeBuilder builds a balanced binary tree of
+/// `vector.shuffle` operations from one or more `vector.to_elements`
+/// operations feeding a single `vector.from_elements` operation.
+///
+/// The implementation generates hardware-agnostic `vector.shuffle` operations
+/// that minimize both the number of shuffle operations and the length of
+/// intermediate vectors (to the extent possible). The tree has the
+/// following properties:
+///
+///   1. Vectors are shuffled in pairs by order of appearance in
+///      the `vector.from_elements` operand list.
+///   2. Each input vector to each level is used only once.
+///   3. The number of levels in the tree is:
+///        ceil(log2(# `vector.to_elements` ops)).
+///   4. Vectors at each level of the tree have the same vector length.
+///   5. Vector positions that do not need to be shuffled are represented with
+///      poison in the shuffle mask.
+///
+/// Example #1: Concatenation of 3x vector<4xf32> to vector<12xf32>:
+///
+///   %0:4 = vector.to_elements %a : vector<4xf32>
+///   %1:4 = vector.to_elements %b : vector<4xf32>
+///   %2:4 = vector.to_elements %c : vector<4xf32>
+///   %3 = vector.from_elements %0#0, %0#1, %0#2, %0#3, %1#0, %1#1,
+///                             %1#2, %1#3, %2#0, %2#1, %2#2, %2#3
+///                               : vector<12xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
+///     : vector<4xf32>, vector<4xf32>
+///   %shuffle1 = vector.shuffle %c, %c [0, 1, 2, 3, -1, -1, -1, -1]
+///     : vector<4xf32>, vector<4xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 2, 3, 4, 5,
+///                                                  6, 7, 8, 9, 10, 11]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * The shuffle tree has two levels:
+///         - Level 1 = (%shuffle0, %shuffle1)
+///         - Level 2 = (%result)
+///     * `%a` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%0#0` and `%1#0`).
+///     * `%c` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector lengths for the first and second levels are 8 and 12,
+///       respectively.
+///     * `%shuffle1` uses poison values to match the vector length of its
+///       tree level (8).
+///
+///
+/// Example #2: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///   =>
+///
+///   %shuffle0 = vector.shuffle %c, %b [2, 6, -1, -1, 7, 2, 0, 6]
+///     : vector<5xf32>, vector<5xf32>
+///   %shuffle1 = vector.shuffle %a, %a [1, 1, -1, -1, -1, -1, 4, -1]
+///     : vector<5xf32>, vector<5xf32>
+///   %result = vector.shuffle %shuffle0, %shuffle1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///     : vector<8xf32>, vector<8xf32>
+///
+///   Comments:
+///     * `%c` and `%b` are shuffled first because they appear first in the
+///       `vector.from_elements` operand list (`%2#2` and `%1#1`).
+///     * `%a` is shuffled with itself because the number of
+///       `vector.to_elements` ops is odd.
+///     * The vector lengths for the first and second levels are 8 and 9,
+///       respectively.
+///     * `%shuffle0` uses poison values to mark unused vector positions and
+///       match the vector length of its tree level (8).
+///
+/// TODO: Implement mask compression to reduce the number of intermediate poison
+/// values.
+///
+class VectorShuffleTreeBuilder {
+public:
+  VectorShuffleTreeBuilder() = delete;
+  VectorShuffleTreeBuilder(FromElementsOp fromElemOp,
+                           ArrayRef<ToElementsOp> toElemDefs);
+
+  /// Analyze the input `vector.to_elements` + `vector.from_elements` sequence
+  /// and compute the shuffle tree configuration. This method does not generate
+  /// any IR.
+  LogicalResult computeShuffleTree();
+
+  /// Materialize the shuffle tree configuration computed by
+  /// `computeShuffleTree` in the IR.
+  Value generateShuffleTree(PatternRewriter &rewriter);
+
+private:
+  // IR input information.
+  FromElementsOp fromElementsOp;
+  SmallVector<ToElementsOp> toElementsDefs;
+
+  // Shuffle tree configuration.
+  unsigned numLevels;
+  SmallVector<unsigned> vectorSizePerLevel;
+  /// Holds the range of positions in the final output that each vector input
+  /// in the tree is contributing to.
+  SmallVector<SmallVector<Interval>> inputIntervalsPerLevel;
+
+  // Utility methods to compute the shuffle tree configuration.
+  void computeInputVectorIntervals();
+  void computeOutputVectorSizePerLevel();
+
+  /// Dump the shuffle tree configuration.
+  void dump();
+};
+
+VectorShuffleTreeBuilder::VectorShuffleTreeBuilder(
+    FromElementsOp fromElemOp, ArrayRef<ToElementsOp> toElemDefs)
+    : fromElementsOp(fromElemOp), toElementsDefs(toElemDefs) {
+
+  assert(fromElementsOp && "from_elements op is required");
+  assert(!toElementsDefs.empty() && "At least one to_elements op is required");
+
+  // Duplicate the last vector if the number of `vector.to_elements` is odd to
+  // simplify the shuffle tree algorithm.
+  if (toElementsDefs.size() % 2 != 0) {
+    toElementsDefs.push_back(toElementsDefs.back());
+  }
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Analysis Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the intervals for all the input vectors in the shuffle tree. The
+/// interval of an input vector is the range of positions in the final output
+/// that the input vector contributes to.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+/// Level 0 has 4 inputs (%2, %1, %0, %0, the last one is duplicated to make the
+/// number of inputs even) so we compute the interval for each input vector:
+///
+///    * inputIntervalsPerLevel[0][0] = interval(%2) = [0,6]
+///    * inputIntervalsPerLevel[0][1] = interval(%1) = [1,7]
+///    * inputIntervalsPerLevel[0][2] = interval(%0) = [2,8]
+///    * inputIntervalsPerLevel[0][3] = interval(%0) = [2,8]
+///
+/// Level 1 has 2 inputs, resulting from the shuffling of %2 + %1 and %0 + %0 so
+/// we compute the intervals for each input vector to level 1 as:
+///    * inputIntervalsPerLevel[1][0] = interval(%2) U interval(%1) = [0,7]
+///    * inputIntervalsPerLevel[1][1] = interval(%0) U interval(%0) = [2,8]
+///
+void VectorShuffleTreeBuilder::computeInputVectorIntervals() {
+  // Map `vector.to_elements` ops to their ordinal position in the
+  // `vector.from_elements` operand list. Make sure duplicated
+  // `vector.to_elements` ops are mapped to their first occurrence.
+  DenseMap<ToElementsOp, unsigned> toElementsToInputOrdinal;
+  for (const auto &[idx, toElementsOp] : llvm::enumerate(toElementsDefs))
+    toElementsToInputOrdinal.insert({toElementsOp, idx});
+
+  // Compute intervals for each input vector in the shuffle tree. The first
+  // level computation is special-cased to keep the implementation simpler.
+
+  SmallVector<Interval> firstLevelIntervals(toElementsDefs.size(),
+                                            {kMaxUnsigned, kMaxUnsigned});
+
+  for (const auto &[idx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto toElementsOp = cast<ToElementsOp>(element.getDefiningOp());
+    unsigned inputIdx = toElementsToInputOrdinal[toElementsOp];
+    Interval &currentInterval = firstLevelIntervals[inputIdx];
+
+    // Set lower bound to the first occurrence of the `vector.to_elements`.
+    if (currentInterval.first == kMaxUnsigned)
+      currentInterval.first = idx;
+
+    // Set upper bound to the last occurrence of the `vector.to_elements`.
+    currentInterval.second = idx;
+  }
+
+  // If the number of `vector.to_elements` is odd and the last op was
+  // duplicated, the interval for the duplicated op was not computed in the
+  // previous step as all the input occurrences were mapped to the original op.
+  // We copy the interval of the original op to the interval of the duplicated
+  // op manually.
+  if (firstLevelIntervals.back().second == kMaxUnsigned)
+    firstLevelIntervals.back() = *std::prev(firstLevelIntervals.end(), 2);
+
+  inputIntervalsPerLevel.push_back(std::move(firstLevelIntervals));
+
+  // Compute intervals for the remaining levels.
+  unsigned outputNumElements =
+      cast<VectorType>(fromElementsOp.getResult().getType()).getNumElements();
+  for (unsigned level = 1; level < numLevels; ++level) {
+    const auto &prevLevelIntervals = inputIntervalsPerLevel[level - 1];
+    SmallVector<Interval> currentLevelIntervals(
+        llvm::divideCeil(prevLevelIntervals.size(), 2),
+        {kMaxUnsigned, kMaxUnsigned});
+
+    for (size_t inputIdx = 0; inputIdx < currentLevelIntervals.size();
+         ++inputIdx) {
+      auto &interval = currentLevelIntervals[inputIdx];
+      const auto &prevLhsInterval = prevLevelIntervals[inputIdx * 2];
+      const auto &prevRhsInterval = prevLevelIntervals[inputIdx * 2 + 1];
+
+      // The interval of a vector at the current level is the union of the
+      // intervals of the two input vectors from the previous level being
+      // shuffled at this level.
+      interval.first = std::min(prevLhsInterval.first, prevRhsInterval.first);
+      interval.second =
+          std::min(std::max(prevLhsInterval.second, prevRhsInterval.second),
+                   outputNumElements - 1);
+    }
+
+    inputIntervalsPerLevel.push_back(std::move(currentLevelIntervals));
+  }
+}
+
+/// Compute the uniform output vector size for each level of the shuffle tree,
+/// given the intervals of the input vectors at that level. The output vector
+/// size of a level is the size of the widest interval resulting from shuffling
+/// each pair of input vectors.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   Intervals:
+///     * Level 0: [0,6], [1,7], [2,8], [2,8]
+///     * Level 1: [0,7], [2,8]
+///
+///   Vector sizes:
+///     * Level 0: max(size_of([0,6] U [1,7] = [0,7]) = 8,
+///                    size_of([2,8] U [2,8] = [2,8]) = 7) = 8
+///
+///     * Level 1: max(size_of([0,7] U [2,8] = [0,8]) = 9) = 9
+///
+void VectorShuffleTreeBuilder::computeOutputVectorSizePerLevel() {
+  // Compute vector size for each level.
+  for (unsigned level = 0; level < numLevels; ++level) {
+    const auto &currentLevelIntervals = inputIntervalsPerLevel[level];
+    unsigned currentVectorSize = 1;
+    for (size_t i = 0; i < currentLevelIntervals.size(); i += 2) {
+      const auto &lhsInterval = currentLevelIntervals[i];
+      const auto &rhsInterval = currentLevelIntervals[i + 1];
+      unsigned combinedIntervalSize =
+          std::max(lhsInterval.second, rhsInterval.second) - lhsInterval.first +
+          1;
+      currentVectorSize = std::max(currentVectorSize, combinedIntervalSize);
+    }
+    vectorSizePerLevel[level] = currentVectorSize;
+  }
+}
+
+void VectorShuffleTreeBuilder::dump() {
+  LLVM_DEBUG({
+    unsigned indLv = 0;
+
+    llvm::dbgs() << "VectorShuffleTreeBuilder Configuration:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Inputs:\n";
+    ++indLv;
+    for (const auto &toElementsOp : toElementsDefs)
+      llvm::dbgs() << llvm::indent(indLv, kIndScale) << toElementsOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << fromElementsOp << "\n\n";
+    --indLv;
+
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Total levels: " << numLevels << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Vector sizes per level: [";
+    llvm::interleaveComma(vectorSizePerLevel, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Input intervals per level:\n";
+    ++indLv;
+    for (const auto &[level, intervals] :
+         llvm::enumerate(inputIntervalsPerLevel)) {
+      llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Level " << level
+                   << ": ";
+      llvm::interleaveComma(intervals, llvm::dbgs(),
+                            [](const Interval &interval) {
+                              llvm::dbgs() << "[" << interval.first << ","
+                                           << interval.second << "]";
+                            });
+      llvm::dbgs() << "\n";
+    }
+  });
+}
+
+/// Compute the shuffle tree configuration for the given `vector.to_elements` +
+/// `vector.from_elements` input sequence. This method builds a balanced binary
+/// shuffle tree that combines pairs of input vectors at each level.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   build a tree that looks like:
+///
+///          %2    %1                   %0    %0
+///            \  /                       \  /
+///  %2_1 = vector.shuffle     %0_0 = vector.shuffle
+///              \                    /
+///             %2_1_0_0 =vector.shuffle
+///
+/// The configuration consists of computing the intervals of the input vectors
+/// at each level of the shuffle tree (i.e., %2, %1, %0, %0, %2_1, %0_0 and
+/// %2_1_0_0) and the output vector size for each level. For further details on
+/// intervals and output vector size computation, please take a look at the
+/// corresponding utility functions.
+LogicalResult VectorShuffleTreeBuilder::computeShuffleTree() {
+  // Initialize shuffle tree information based on its size.
+  assert(toElementsDefs.size() > 1 &&
+         "At least two 'vector.to_elements' ops are required");
+  numLevels = llvm::Log2_64(toElementsDefs.size());
+  vectorSizePerLevel.resize(numLevels, 0);
+  inputIntervalsPerLevel.reserve(numLevels);
+
+  computeInputVectorIntervals();
+  computeOutputVectorSizePerLevel();
+  dump();
+
+  return success();
+}
+
+// ===--------------------------------------------------------------------===//
+// Shuffle Tree Code Generation Utilities.
+// ===--------------------------------------------------------------------===//
+
+/// Compute the permutation mask for shuffling two input `vector.to_elements`
+/// ops. The permutation mask is the mapping of the input vector elements to
+/// their final position in the output vector, relative to the intermediate
+/// output vector of the `vector.shuffle` operation combining the two inputs.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   =>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+/// TODO: Implement mask compression.
+static SmallVector<int64_t> computePermutationShuffleMask(
+    ToElementsOp toElementOp0, const Interval &interval0,
+    ToElementsOp toElementOp1, const Interval &interval1,
+    FromElementsOp fromElementsOp, unsigned outputVectorSize) {
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+  unsigned inputVectorSize =
+      toElementOp0.getSource().getType().getNumElements();
+
+  for (const auto &[inputIdx, element] :
+       llvm::enumerate(fromElementsOp.getElements())) {
+    auto currentToElemOp = cast<ToElementsOp>(element.getDefiningOp());
+    // Match `vector.from_elements` operands to the two input ops.
+    if (currentToElemOp != toElementOp0 && currentToElemOp != toElementOp1)
+      continue;
+
+    // The permutation value for a particular operand is the ordinal position of
+    // the operand in the `vector.to_elements` list of results.
+    unsigned permVal = cast<OpResult>(element).getResultNumber();
+    unsigned maskIdx = inputIdx;
+
+    // The mask index is the ordinal position of the operand in
+    // `vector.from_elements` operand list. We make this position relative to
+    // the interval of the output vector resulting from combining the two
+    // input vectors.
+    if (currentToElemOp == toElementOp0) {
+      maskIdx -= interval0.first;
+    } else {
+      // currentToElemOp == toElementOp1
+      unsigned intervalOffset = interval1.first - interval0.first;
+      maskIdx += intervalOffset - interval1.first;
+      permVal += inputVectorSize;
+    }
+
+    mask[maskIdx] = permVal;
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Permutation mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Combining: " << toElementOp0 << " and " << toElementOp1
+                 << "\n";
+  });
+
+  return mask;
+}
+
+/// Compute the propagation shuffle mask for combining two intermediate shuffle
+/// operations of the tree. The propagation shuffle mask is the mapping of the
+/// intermediate vector elements, which have already been shuffled to their
+/// relative output position using the mask generated by
+/// `computePermutationShuffleMask`, to their next position in the tree.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   // Level 0, vector length = 8
+///   %2_1 = PermutationShuffleMask(%2, %1) = [2, 6, -1, -1, 7, 2, 0, 6]
+///   %0_0 = PermutationShuffleMask(%0, %0) = [1, 1, -1, -1, -1, -1, 4, -1]
+///
+///   =>
+///
+///   // Level 1, vector length = 9
+///   PropagationShuffleMask(%2_1, %0_0) = [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///
+/// TODO: Implement mask compression.
+///
+static SmallVector<int64_t> computePropagationShuffleMask(
+    ShuffleOp lhsShuffleOp, const Interval &lhsInterval, ShuffleOp rhsShuffleOp,
+    const Interval &rhsInterval, unsigned outputVectorSize) {
+  ArrayRef<int64_t> lhsShuffleMask = lhsShuffleOp.getMask();
+  ArrayRef<int64_t> rhsShuffleMask = rhsShuffleOp.getMask();
+  unsigned inputVectorSize = lhsShuffleMask.size();
+  assert(inputVectorSize == rhsShuffleMask.size() &&
+         "Expected both shuffle masks to have the same size");
+
+  unsigned lhsRhsOffset = rhsInterval.first - lhsInterval.first;
+  SmallVector<int64_t> mask(outputVectorSize, ShuffleOp::kPoisonIndex);
+
+  // Propagate any element from the input mask that is not poison. For the RHS
+  // input vector, the mask index is offset by the offset between the two
+  // intervals of the input vectors.
+  for (unsigned i = 0; i < inputVectorSize; ++i) {
+    if (lhsShuffleMask[i] != ShuffleOp::kPoisonIndex)
+      mask[i] = i;
+
+    unsigned rhsIdx = i + lhsRhsOffset;
+    if (rhsShuffleMask[i] != ShuffleOp::kPoisonIndex) {
+      assert(rhsIdx < outputVectorSize && "RHS index out of bounds");
+      assert(mask[rhsIdx] == ShuffleOp::kPoisonIndex && "mask already set");
+      mask[rhsIdx] = i + inputVectorSize;
+    }
+  }
+
+  LLVM_DEBUG({
+    unsigned indLv = 1;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* Propagation shuffle mask computation:\n";
+    ++indLv;
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* LHS shuffle op: " << lhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale)
+                 << "* RHS shuffle op: " << rhsShuffleOp << "\n";
+    llvm::dbgs() << llvm::indent(indLv, kIndScale) << "* Result mask: [";
+    llvm::interleaveComma(mask, llvm::dbgs());
+    llvm::dbgs() << "]\n";
+  });
+
+  return mask;
+}
+
+/// Materialize the pre-computed shuffle tree configuration in the IR by
+/// generating the corresponding `vector.shuffle` ops.
+///
+/// Example: Arbitrary shuffling of 3x vector<5xf32> to vector<9xf32>:
+///
+///   %0:5 = vector.to_elements %a : vector<5xf32>
+///   %1:5 = vector.to_elements %b : vector<5xf32>
+///   %2:5 = vector.to_elements %c : vector<5xf32>
+///   %3 = vector.from_elements %2#2, %1#1, %0#1, %0#1, %1#2,
+///                             %2#2, %2#0, %1#1, %0#4 : vector<9xf32>
+///
+///   with the pre-computed shuffle tree configuration:
+///
+///     * Vector sizes per level: [8, 9]
+///     * Input intervals per level:
+///       * Level 0: [0,6], [1,7], [2,8], [2,8]
+///       * Level 1: [0,7], [2,8]
+///
+///   =>
+///
+///    %0 = vector.shuffle %arg2, %arg1 [2, 6, -1, -1, 7, 2, 0, 6]
+///        : vector<5xf32>, vector<5xf32>
+///    %1 = vector.shuffle %arg0, %arg0 [1, 1, -1, -1, -1, -1, 4, -1]
+///        : vector<5xf32>, vector<5xf32>
+///    %2 = vector.shuffle %0, %1 [0, 1, 8, 9, 4, 5, 6, 7, 14]
+///        : vector<8xf32>, vector<8xf32>
+///
+/// The code generation consists of combining pairs of input vectors for each
+/// level of the tree, using the pre-computed per tree level intervals and
+/// vector sizes. The algorithm generates two kinds of shuffle masks:
+/// permutation masks and propagation masks. Permutation masks are computed for
+/// the first level of the tree and permute the input vector elements to their
+/// relative position in the final output. Propagation masks are computed for
+/// subsequent levels and propagate the elements to the next level without
+/// permutation. For further details on the shuffle mask computation, please
+/// take a look at the corresponding `computePermutationShuffleMask` and
+/// `computePropagationShuffleMask` functions.
+///
+Value VectorShuffleTreeBuilder::generateShuffleTree(PatternRewriter &rewriter) {
+  LLVM_DEBUG(llvm::dbgs() << "VectorShuffleTreeBuilder Code Generation:\n");
+
+  // Initialize work list with the `vector.to_elements` sources.
+  SmallVector<Value> levelInputs;
+  llvm::transform(
+      toElementsDefs, std::back_inserter(levelInputs),
+      [](ToElementsOp toElementsOp) { return toElementsOp.getSource(); });
+
+  // Build shuffle tree by combining pairs of vectors.
+  Location loc = fromElementsOp.getLoc();
+  unsigned currentLevel = 0;
+  for (const auto &[levelVectorSize, inputIntervals] :
+       llvm::zip_equal(vectorSizePerLevel, inputIntervalsPerLevel)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << llvm::indent(1, kIndScale) << "* Processing level "
+               << currentLevel << " (vector size: " << levelVectorSize
+               << ", # inputs: " << levelInputs.size() << ")\n");
+
+    // Process level input vectors in pairs.
+    SmallVector<Value> levelOutputs;
+    for (size_t i = 0; i < levelInputs.size(); i += 2) {
+      Value lhsVector = levelInputs[i];
+      Value rhsVector = levelInputs[i + 1];
+      const Interval &lhsInterval = inputIntervals[i];
+      const Interval &rhsInterval = inputIntervals[i + 1];
+
+      // For the first level of the tree, permute the vector elements to their
+      // relative position in the final output. For subsequent levels, we
+      // propagate the elements to the next level without permutation.
+      SmallVector<int64_t> shuffleMask;
+      if (currentLevel == 0) {
+        shuffleMask = computePermutationShuffleMask(
+            toElementsDefs[i], lhsInterval, toElementsDefs[i + 1], rhsInterval,
+            fromElementsOp, levelVectorSize);
+      } else {
+        auto lhsShuffleOp = cast<ShuffleOp>(lhsVector.getDefiningOp());
+        auto rhsShuffleOp = cast<ShuffleOp>(rhsVector.getDefiningOp());
+        shuffleMask = computePropagationShuffleMask(lhsShuffleOp, lhsInterval,
+                                                    rhsShuffleOp, rhsInterval,
+                                                    levelVectorSize);
+      }
+
+      Value shuffleVal = rewriter.create<vector::ShuffleOp>(
+          loc, lhsVector, rhsVector, shuffleMask);
+      levelOutputs.push_back(shuffleVal);
+    }
+
+    levelInputs = std::move(levelOutputs);
+    ++currentLevel;
+  }
+
+  assert(levelInputs.size() == 1 && "Should have exactly one result");
+  return levelInputs.front();
+}
+
+/// Gather and unique all the `vector.to_elements` operations that feed the
+/// `vector.from_elements` operation. The `vector.to_elements` operations are
+/// returned in order of appearance in the `vector.from_elements`'s operand
+/// list.
+static LogicalResult
+getToElementsDefiningOps(FromElementsOp fromElementsOp,
+                         SmallVectorImpl<ToElementsOp> &toElementsDefs) {
+  SetVector<ToElementsOp> toElementsDefsSet;
+  for (Value element : fromElementsOp.getElements()) {
+    auto toElementsOp = element.getDefiningOp<ToElementsOp>();
+    if (!toElementsOp)
+      return failure();
+    toElementsDefsSet.insert(toElementsOp);
+  }
+
+  toElementsDefs.assign(toElementsDefsSet.begin(), toElementsDefsSet.end());
+  return success();
+}
+
+/// Pattern to rewrite `vector.to_elements` + `vector.from_elements` sequences into
+/// a tree of `vector.shuffle` operations.
+struct ToFromElementsToShuffleTreeRewrite final
+    : OpRewritePattern<vector::FromElementsOp> {
+
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::FromElementsOp fromElementsOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = fromElementsOp.getType();
+    if (resultType.getRank() != 1 || resultType.isScalable())
+      return failure();
+
+    SmallVector<ToElementsOp> toElementsDefs;
+    if (failed(getToElementsDefiningOps(fromElementsOp, toElementsDefs)))
+      return failure();
+
+    // Avoid generating a shuffle tree for trivial `vector.to_elements` ->
+    // `vector.from_elements` forwarding cases that do not require shuffling.
----------------
newling wrote:

If I understand correctly, there is an op folder for this and it will probably already have been removed. My preference here would be to remove this (pass should still work afaict, this should happen automatically between the shuffle and from_elements folders).

https://github.com/llvm/llvm-project/pull/145740


More information about the Mlir-commits mailing list